@@ -77,15 +77,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,15 +132,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
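For reference, a minimal standalone sketch (not part of the patch) of the dtype mapping the updated assertions encode: both hunks change the test so that the `*-mixed` precision flags are expected to use the low-precision dtype for parameters as well, matching the `*-true` flags. The helper name `expected_fsdp_dtypes` is hypothetical:

```python
import torch


def expected_fsdp_dtypes(precision: str) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
    """Return the expected (param_dtype, reduce_dtype, buffer_dtype) for a precision flag."""
    if precision in ("16-true", "16-mixed"):
        return (torch.float16,) * 3
    if precision in ("bf16-true", "bf16-mixed"):
        return (torch.bfloat16,) * 3
    raise ValueError(f"Unknown precision {precision}")


# Before this change the test expected param_dtype=float32 for "16-mixed";
# now all four flags expect params, reductions, and buffers in low precision.
assert expected_fsdp_dtypes("16-mixed") == (torch.float16, torch.float16, torch.float16)
assert expected_fsdp_dtypes("bf16-mixed") == (torch.bfloat16, torch.bfloat16, torch.bfloat16)
```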