Skip to content

Commit 3c6c84a

Browse files
committed
add ignore
1 parent c5bde43 commit 3c6c84a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/lightning/fabric/utilities/spike.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,16 @@ def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
126126
raise TrainingSpikeException(batch_idx=batch_idx)
127127

128128
def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
129-
return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))
129+
return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol)) # type: ignore
130130

131131
def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
132-
return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))
132+
return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b)) # type: ignore
133133

134134
def _is_better(self, diff_val: torch.Tensor) -> bool:
135135
if self.mode == "min":
136-
return bool((diff_val <= 0.0).all())
136+
return bool((diff_val <= 0.0).all()) # type: ignore[operator]
137137
if self.mode == "max":
138-
return bool((diff_val >= 0).all())
138+
return bool((diff_val >= 0).all()) # type: ignore[operator]
139139

140140
raise ValueError(f"Invalid mode. Has to be min or max, found {self.mode}")
141141

0 commit comments

Comments
 (0)