Skip to content

Commit 90cf1a2

Browse files
committed
pre-commit lint fixes
1 parent 4de419c commit 90cf1a2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/haliax/nn/linear.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,14 @@ def _out_first(self):
159159

160160
def to_state_dict(self, prefix: Optional[str] = None) -> StateDict:
161161
# weight can be None for certain filtering things like LoRA
162-
scaled = dataclasses.replace(self, weight=self.weight * self.reparam.active_scale if self.weight is not None else None)
162+
scaled = dataclasses.replace(
163+
self, weight=self.weight * self.reparam.active_scale if self.weight is not None else None
164+
)
163165
return default_eqx_module_to_state_dict(scaled, prefix)
164166

165167
def from_state_dict(self: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod:
166168
unscaled = default_eqx_module_from_state_dict(self, state_dict, prefix)
167-
if unscaled.weight is not None:
169+
if unscaled.weight is not None:
168170
unscaled = dataclasses.replace(unscaled, weight=unscaled.weight / self.reparam.active_scale)
169171
return unscaled
170172

0 commit comments

Comments
 (0)