Skip to content

Commit 0b4fcec

Browse files
author
Hossein Kavianihamedani
committed
Fix flake8 linting: Add spaces around arithmetic operators
1 parent 4345efc commit 0b4fcec

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tests/integration_tests/test_titan_fwd_vs_hf_fwd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def compare_logits(
213213
hf_val = hf_logits_cpu[pos].item()
214214
diff_val = abs_diff[pos].item()
215215
print(
216-
f" {i+1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
216+
f" {i + 1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
217217
)
218218

219219
return metrics
@@ -242,12 +242,12 @@ def compare_probabilities(
242242
zip(titan_top_k.values, titan_top_k.indices)
243243
):
244244
token = tokenizer.decode([token_id.item()])
245-
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
245+
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
246246

247247
print("\nHugging Face Top-K:")
248248
for i, (prob, token_id) in enumerate(zip(hf_top_k.values, hf_top_k.indices)):
249249
token = tokenizer.decode([token_id.item()])
250-
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
250+
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
251251

252252
# Calculate overlap in top-k predictions
253253
titan_top_tokens = set(titan_top_k.indices.tolist())

tests/unit_tests/datasets/test_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
231231
# But should contain the same set of IDs
232232
assert set(first_epoch_ids) == set(
233233
range(SMALL_DATASET_SIZE)
234-
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
234+
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {first_epoch_ids}"
235235
assert set(second_epoch_ids) == set(
236236
range(SMALL_DATASET_SIZE)
237-
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
237+
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {second_epoch_ids}"
238238

239239
def test_epoch_tracking(self, dataset_factory, small_dataset_file):
240240
"""Test that epoch number is correctly tracked across dataset restarts."""

0 commit comments

Comments
 (0)