@@ -504,10 +504,7 @@ def __init__(
504504
505505 self .in_features = in_features
506506 self .out_features = out_features
507- if bias :
508- self .bias = torch .nn .Parameter (torch .empty (out_features , ** self .factory_kwargs ))
509- else :
510- self .register_parameter ("bias" , None )
507+ self ._has_bias = bias
511508
512509 self .tensor_class = None
513510 self ._full_precision_mm = MixedPrecisionOps ._full_precision_mm
@@ -536,6 +533,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
536533 self .weight = torch .nn .Parameter (weight .to (device = device , dtype = dtype ), requires_grad = False )
537534 if dtype != MixedPrecisionOps ._compute_dtype :
538535 self .comfy_cast_weights = True
536+ if self ._has_bias :
537+ self .bias = torch .nn .Parameter (torch .empty (self .out_features , device = device , dtype = dtype ))
538+ else :
539+ self .register_parameter ("bias" , None )
539540 else :
540541 self .quant_format = layer_conf .get ("format" , None )
541542 if not self ._full_precision_mm :
@@ -565,6 +566,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
565566 requires_grad = False
566567 )
567568
569+ if self ._has_bias :
570+ self .bias = torch .nn .Parameter (torch .empty (self .out_features , device = device , dtype = MixedPrecisionOps ._compute_dtype ))
571+ else :
572+ self .register_parameter ("bias" , None )
573+
568574 for param_name in qconfig ["parameters" ]:
569575 param_key = f"{ prefix } { param_name } "
570576 _v = state_dict .pop (param_key , None )
0 commit comments