
Commit c5a47a1

Fix bias dtype issue in mixed ops. (#11293)
1 parent 908fd7d commit c5a47a1

File tree

1 file changed (+10, -4 lines)


comfy/ops.py

Lines changed: 10 additions & 4 deletions
@@ -504,10 +504,7 @@ def __init__(
 
         self.in_features = in_features
         self.out_features = out_features
-        if bias:
-            self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
+        self._has_bias = bias
 
         self.tensor_class = None
         self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@@ -536,6 +533,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
             self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
             if dtype != MixedPrecisionOps._compute_dtype:
                 self.comfy_cast_weights = True
+            if self._has_bias:
+                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
+            else:
+                self.register_parameter("bias", None)
         else:
             self.quant_format = layer_conf.get("format", None)
             if not self._full_precision_mm:
@@ -565,6 +566,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                     requires_grad=False
                 )
 
+            if self._has_bias:
+                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
+            else:
+                self.register_parameter("bias", None)
+
             for param_name in qconfig["parameters"]:
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)
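
For context, here is a minimal sketch of the pattern this commit adopts: instead of allocating the bias in __init__ (where the eventual compute dtype is not yet known), the layer only records whether a bias is wanted and allocates it during the state-dict load, once the loaded weight's device and dtype are known. The class and variable names below are hypothetical, not ComfyUI's actual code, and the sketch loads the bias value directly rather than deferring to the surrounding loading logic as the real ops do.

import torch
import torch.nn.functional as F


class DeferredBiasLinear(torch.nn.Module):
    """Sketch of a linear layer that defers bias allocation to load time."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Only remember whether a bias is wanted; the final dtype is not
        # known until the checkpoint is loaded, so allocate nothing yet.
        self._has_bias = bias
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Keep the weight in whatever dtype the checkpoint stored it in.
        weight = state_dict[prefix + "weight"]
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if self._has_bias:
            bias = state_dict.get(prefix + "bias")
            if bias is None:
                bias = torch.zeros(self.out_features)
            # Create the bias only now, matching the weight's device/dtype,
            # so a mixed-precision load cannot leave it at a stale dtype.
            self.bias = torch.nn.Parameter(
                bias.to(device=weight.device, dtype=weight.dtype),
                requires_grad=False,
            )
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


# Usage: loading an fp16 checkpoint yields an fp16 bias automatically.
layer = DeferredBiasLinear(4, 8)
sd = {"weight": torch.randn(8, 4, dtype=torch.float16),
      "bias": torch.randn(8, dtype=torch.float32)}
layer.load_state_dict(sd)
print(layer.bias.dtype)  # torch.float16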
