@@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
9999
100100
101101def get_parameter_dtype (parameter : torch .nn .Module ) -> torch .dtype :
102- try :
103- return next (parameter .parameters ()).dtype
104- except StopIteration :
105- try :
106- return next (parameter .buffers ()).dtype
107- except StopIteration :
108- # For torch.nn.DataParallel compatibility in PyTorch 1.5
109-
110- def find_tensor_attributes (module : torch .nn .Module ) -> List [Tuple [str , Tensor ]]:
111- tuples = [(k , v ) for k , v in module .__dict__ .items () if torch .is_tensor (v )]
112- return tuples
113-
114- gen = parameter ._named_members (get_members_fn = find_tensor_attributes )
115- first_tuple = next (gen )
116- return first_tuple [1 ].dtype
102+ """
103+ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
104+ """
105+ last_dtype = None
106+ for param in parameter .parameters ():
107+ last_dtype = param .dtype
108+ if param .is_floating_point ():
109+ return param .dtype
110+
111+ for buffer in parameter .buffers ():
112+ last_dtype = buffer .dtype
113+ if buffer .is_floating_point ():
114+ return buffer .dtype
115+
116+ if last_dtype is not None :
117+ # if no floating dtype was found return whatever the first dtype is
118+ return last_dtype
119+
120+ # For nn.DataParallel compatibility in PyTorch > 1.5
121+ def find_tensor_attributes (module : nn .Module ) -> List [Tuple [str , Tensor ]]:
122+ tuples = [(k , v ) for k , v in module .__dict__ .items () if torch .is_tensor (v )]
123+ return tuples
124+
125+ gen = parameter ._named_members (get_members_fn = find_tensor_attributes )
126+ last_tuple = None
127+ for tuple in gen :
128+ last_tuple = tuple
129+ if tuple [1 ].is_floating_point ():
130+ return tuple [1 ].dtype
131+
132+ if last_tuple is not None :
133+ # fallback to the last dtype
134+ return last_tuple [1 ].dtype
117135
118136
119137class ModelMixin (torch .nn .Module , PushToHubMixin ):
0 commit comments