
format everything with ruff#68

Merged
RobotSail merged 2 commits into Red-Hat-AI-Innovation-Team:main from RobotSail:ruff-format on Feb 5, 2026

Conversation

RobotSail (Collaborator) commented Feb 5, 2026

Summary by CodeRabbit

  • New Features

    • Training CLI/types: added seed and weight_decay options; some training args made more explicit/strict and dtype values accept string forms.
    • New/reshuffled exports: EpochSampler added; create_osft_model_class exposed for test/util use.
  • Bug Fixes

    • Improved robustness of Weights & Biases (wandb) integration and added early availability checks.
  • Chores

    • Widespread code-style/typing modernization, formatting cleanup, test updates, and linter configuration added.

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Applies widespread formatting and typing modernizations across the codebase (quote normalization, line-wrapping, PEP 604 unions), plus a targeted change: adding step tracking to orthogonality metrics and related reporting. Most edits are stylistic; a few files update public type annotations and data-class fields.

Changes

Cohort / File(s) Summary
Orthogonality metrics
regression_tests/test_osft_orthogonalization.py, tests/test_utils/orthogonality.py, tests/test_utils/__init__.py
Added step: int to OrthogonalityMetrics, updated tracker storage/return types (list[dict]), adjusted update signatures/uses to include step, and enhanced reporting/formatting.
Public typing & data-class updates
src/mini_trainer/training_types.py, src/mini_trainer/setup_model_for_training.py, src/mini_trainer/mlflow_wrapper.py, src/mini_trainer/wandb_wrapper.py
Converted many Optional/Dict annotations to PEP 604 unions and builtin generics (e.g., `str | None`).
FSDP / research helpers
research_scratch/fsdp1_dummy_script.py, research_scratch/fsdp1_wrapper.py, research_scratch/sequence_length_experiment.py
Added/adjusted FSDP auto-wrap, sharding_strategy and mixed-precision args; signature/type-annotation modernizations and minor formatting.
Formatting-heavy modules
src/mini_trainer/batch_metrics.py, src/mini_trainer/batch_packer.py, src/mini_trainer/gpt_oss_utils.py, src/mini_trainer/none_reduction_losses.py, src/mini_trainer/sampler.py, src/mini_trainer/utils.py, src/mini_trainer/api_train.py, src/mini_trainer/train.py, src/mini_trainer/fsdp2_lazy_init.py, src/mini_trainer/async_structured_logger.py, src/mini_trainer/osft_utils.py, src/mini_trainer/__init__.py, src/mini_trainer/mini_trainer/*
Widespread style changes: quote normalization, whitespace/line-wrapping, minor logging/message reflows. One public API typing change in gpt_oss_utils.convert_dequantized_to_quantized_format_correct (Dict -> dict). No runtime logic changes except typing.
Regression & benchmark tests
regression_tests/benchmark_batching.py, regression_tests/test_gpt_oss_conversion_accuracy.py, regression_tests/test_osft_fidelity_script.py
Reflowed imports/parameters and restructured internal data assembly (minibatch/seq-length handling) and standardized return-key quoting; behavior claimed equivalent but internal load-assembly refactor warrants attention.
Tests (format + small behavior tests)
tests/** (many files; see summary) tests/conftest.py, tests/gpu_tests/*, tests/test_api_train.py, tests/test_async_logger.py, tests/test_batch_lengths_to_minibatches.py, tests/test_dtype_conversion.py, tests/test_integration_small_models.py, tests/test_osft*.py, tests/test_training_components.py, tests/test_training_loop.py, tests/test_utils/*, tests/test_data_loader*.py, tests/test_pretraining_dataset.py, tests/test_model_initialization.py, ...
Extensive test-formatting: quote/style normalization, decorator reflows, some added imports in tests. Notable test behavior changes: dtype tests now pass dtype params as strings; a small assertion added checking a timeout kwarg in init_process_group. Mostly non-functional.
Scripts & utilities
scripts/convert_to_pretrain.py, scripts/process_data.py, test-vector-projection.py, tutorials/tensor_paralleism_getting_started.py
Minor CLI/IO formatting and exception handling tweaks, tokenization call reflows, and consistent quoting; no functional API changes.
Config/tooling
pyproject.toml
Added comprehensive Ruff config and lint/format settings (line-length, quote style, per-file ignores).

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • Add MLflow logging support #66: Overlaps mlflow and async logging changes touching src/mini_trainer/mlflow_wrapper.py, async_structured_logger.py, and training logging integration.
  • Add OSFT orthogonalization test #50: Related to orthogonality utilities and tests; overlaps with tests/test_utils/orthogonality.py and orthogonality test updates.
  • Adds Wandb + validation loss #38: Related wandb/async logging changes and ImportError fallback patterns (affects wandb_wrapper.py, logger integration).

Suggested reviewers

  • NikhilNayak-debug
  • Maxusmusti

Poem

🐰 I hopped through lines, fixed quotes and space,
step numbers tucked in metrics' place.
Whitespace trimmed, the tests now gleam,
types modernized like a tidy dream.
A rabbit's nibble—code neat as lace.

🚥 Pre-merge checks | ✅ 3 passed
  • Description check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title 'format everything with ruff' directly and accurately describes the main objective of the PR, which is to apply Ruff formatting across the entire codebase.
  • Docstring coverage (✅ Passed): Docstring coverage is 85.15%, which is above the required threshold of 80.00%.



codecov bot commented Feb 5, 2026

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/gpu_tests/test_mixed_precision.py`:
- Around line 87-95: The test patches
transformers.loss.loss_utils.fixed_cross_entropy via patch_target_module with
hf_fixed_cross_entropy_none_reduction but never restores the original, causing
cross-test leakage; modify the test to save the original value of
transformers.loss.loss_utils.fixed_cross_entropy before calling
patch_target_module (or use patch_target_module's return if it provides the
original), and restore that original in a finally block after the test block so
the original fixed_cross_entropy is reinstated even on errors.
🧹 Nitpick comments (3)
research_scratch/fsdp1_wrapper.py (1)

42-42: Consider adding validation for _no_split_modules.

While not related to the formatting changes, accessing model._no_split_modules[0] could raise AttributeError or IndexError if the attribute doesn't exist or the list is empty. Consider adding a check before accessing the first element.

🛡️ Optional defensive check
+    if not hasattr(model, '_no_split_modules') or not model._no_split_modules:
+        raise ValueError("Model must have non-empty _no_split_modules attribute")
     block_name = model._no_split_modules[0]
tests/gpu_tests/test_mixed_precision.py (1)

124-126: Normalize the unreduced loss by token count, not just batch size.

With per-token losses, dividing only by batch_size scales gradients with sequence length. Using a token-count normalization better mirrors training semantics.

♻️ Suggested adjustment
-            loss = loss.float().sum() / batch_size
+            loss = loss.float().sum() / labels.numel()

Based on learnings: In the mini_trainer codebase, loss functions are patched to return unreduced per-token losses, so outputs.loss must be summed and normalized.
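A toy, framework-free sketch of the normalization concern (numbers invented for illustration): summing per-token losses and dividing only by batch size grows with sequence length, while dividing by the token count recovers the per-token mean.

```python
# Two sequences of different lengths, each with a per-token loss of 0.5.
per_token_losses = [
    [0.5, 0.5, 0.5, 0.5],  # 4 tokens
    [0.5, 0.5],            # 2 tokens
]
flat = [tok for seq in per_token_losses for tok in seq]

batch_size = len(per_token_losses)  # 2
num_tokens = len(flat)              # 6

loss_by_batch = sum(flat) / batch_size   # 1.5: scales with sequence length
loss_by_tokens = sum(flat) / num_tokens  # 0.5: the true per-token mean

assert loss_by_batch == 1.5
assert loss_by_tokens == 0.5
```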

tests/test_osft_dtype_functionality.py (1)

9-14: Unused import: auto_generate_target_osft_config

The newly added import auto_generate_target_osft_config does not appear to be used anywhere in this test file. If this was intentionally added for future use, consider removing it until needed to keep imports clean.

🔧 Proposed fix
 from mini_trainer.osft_utils import (
     create_svd_dict,
     reconstruct_weight_matrix,
     create_osft_model_class,
-    auto_generate_target_osft_config,
 )

Comment on lines 87 to 95

     # Patch loss function for none reduction
    -from mini_trainer.none_reduction_losses import hf_fixed_cross_entropy_none_reduction
    +from mini_trainer.none_reduction_losses import (
    +    hf_fixed_cross_entropy_none_reduction,
    +)

     patch_target_module(
         "transformers.loss.loss_utils.fixed_cross_entropy",
         hf_fixed_cross_entropy_none_reduction,
     )

⚠️ Potential issue | 🟡 Minor

Restore the patched loss function to avoid cross-test leakage.

patch_target_module mutates a global in transformers. Without restoring it in finally, later tests in the same process can inherit the patched behavior.

🔧 Proposed fix (restore in finally)
+            from transformers.loss.loss_utils import fixed_cross_entropy as _orig_fixed_ce
             patch_target_module(
                 "transformers.loss.loss_utils.fixed_cross_entropy",
                 hf_fixed_cross_entropy_none_reduction,
             )
@@
         finally:
             # Cleanup distributed environment
+            patch_target_module(
+                "transformers.loss.loss_utils.fixed_cross_entropy",
+                _orig_fixed_ce,
+            )
             if dist.is_initialized():
                 dist.destroy_process_group()

Also applies to: 226-229


coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/test_dtype_conversion.py (1)

227-247: ⚠️ Potential issue | 🟡 Minor

Patch destroy_distributed_environment to avoid real teardown side effects.
If main() gets far enough to reach teardown, this test may call the real destroy function. Keep it mocked for consistency with the other integration tests.

Suggested fix
 @patch("torch.distributed.get_world_size", return_value=1)
 @patch("mini_trainer.train.get_node_rank", return_value=0)
+@patch("mini_trainer.train.destroy_distributed_environment")
 @patch("mini_trainer.train.init_distributed_environment")
 @patch("torch.distributed.get_rank")
-def test_main_invalid_dtype_raises_error(self, mock_get_rank, mock_init_dist, mock_get_node_rank, mock_world_size):
+def test_main_invalid_dtype_raises_error(
+    self,
+    mock_get_rank,
+    mock_init_dist,
+    mock_destroy_dist,
+    mock_get_node_rank,
+    mock_world_size,
+):
research_scratch/fsdp1_wrapper.py (1)

18-29: ⚠️ Potential issue | 🟡 Minor

Fix the return type annotation to reflect class objects.

get_module_class_from_name returns a module class, not a module instance. The current annotation (torch.nn.Module | None) is misleading for type checkers and incorrect for the actual return values (e.g., model.__class__).

✅ Suggested fix
-def get_module_class_from_name(model: torch.nn.Module, name: str) -> torch.nn.Module | None:
+def get_module_class_from_name(model: torch.nn.Module, name: str) -> type[torch.nn.Module] | None:
     modules_children = list(model.children())

     if model.__class__.__name__ == name:
         return model.__class__
     elif len(modules_children) == 0:
-        return
+        return None
     else:
         for child_module in modules_children:
             module_class = get_module_class_from_name(child_module, name)
             if module_class is not None:
                 return module_class
+    return None
🤖 Fix all issues with AI agents
In `@pyproject.toml`:
- Around line 160-167: The per-file-ignore globs under the
tool.ruff.lint.per-file-ignores table are not recursive; update each
directory-level pattern to use recursive globs so nested files are matched
(e.g., change "tests/*" to "tests/**"); apply the same update for other
directory patterns such as "regression_tests/*" -> "regression_tests/**",
"research_scratch/*" -> "research_scratch/**", "tutorials/*" -> "tutorials/**"
and "scripts/*" -> "scripts/**" while leaving single-file patterns (like
"__init__.py" and "test-*.py") unchanged so the existing ignore lists (S101,
F841, F401, F811, E722, etc.) continue to apply.

In `@regression_tests/test_gpt_oss_conversion_accuracy.py`:
- Around line 418-429: The file defines two functions named test_inference
causing the parameterless one to shadow the real helper and trigger TypeError;
rename the import-validation function (the parameterless test that imports
convert_dequantized_to_quantized_format_correct and
update_config_for_quantized_format) to a unique name such as
test_inference_imports or test_gpt_oss_utils_imports, leave its body unchanged,
and ensure no other code references the old name so test_inference(model,
tokenizer, stage_name) remains the callable inference helper used by
test_conversion_accuracy and other tests.

In `@regression_tests/test_osft_orthogonalization.py`:
- Line 50: The type annotation for self.metrics incorrectly declares the inner
values as OrthogonalityMetrics; update the annotation to reflect that the stored
values are plain nested dicts (not instances). Replace the declaration of
self.metrics with a nested-dict type (for example use dict[str, dict[str,
dict[str, float]]] or a more permissive dict[str, dict[str, Any]]/Mapping if
values vary) so the annotation matches the actual stored structure and reference
the existing symbol self.metrics and the OrthogonalityMetrics type when making
the change.

In `@tests/test_utils/orthogonality.py`:
- Line 32: The type annotation for self.metrics is incorrect: it declares
dict[str, dict[str, OrthogonalityMetrics]] but the update() method stores plain
dicts (string keys/values) rather than OrthogonalityMetrics instances; fix by
either changing the annotation to match stored data (e.g., dict[str, dict[str,
str]] or dict[str, dict[str, Any]] as appropriate) or alter update() to
construct and store OrthogonalityMetrics dataclass instances instead of plain
dicts; locate the self.metrics declaration and the update() method in
tests/test_utils/orthogonality.py and make the annotation and stored value types
consistent (referencing self.metrics and update()).
🧹 Nitpick comments (2)
scripts/process_data.py (1)

28-29: Consider using a more specific exception type.

The bare except: clause catches all exceptions including SystemExit and KeyboardInterrupt, which is generally discouraged per PEP 8. Consider using except Exception: for consistency with line 98.

💡 Optional: Use specific exception type
         message_token = tokenizer.encode("<|message|>", add_special_tokens=False)
-    except:
+    except Exception:
         return False
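A small demonstration of the difference (helper names invented): a bare except: swallows SystemExit and KeyboardInterrupt too, while except Exception lets them propagate.

```python
def swallow_bare(exc: BaseException) -> bool:
    try:
        raise exc
    except:  # noqa: E722 (bare except catches BaseException)
        return True

def swallow_exception(exc: BaseException) -> bool:
    try:
        raise exc
    except Exception:  # does NOT catch SystemExit / KeyboardInterrupt
        return True

assert swallow_bare(ValueError("x"))
assert swallow_bare(SystemExit(1))  # even SystemExit is swallowed

propagated = False
try:
    swallow_exception(SystemExit(1))  # SystemExit escapes except Exception
except SystemExit:
    propagated = True
assert propagated
```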
regression_tests/test_osft_orthogonalization.py (1)

35-201: Consider importing from tests.test_utils.orthogonality to reduce duplication.

OrthogonalityMetrics, OrthogonalityTracker, compute_angle_differences, check_gradient_orthogonality, and check_parameter_orthogonality are duplicated between this file and tests/test_utils/orthogonality.py. Since the test utilities module now exists (and was extracted from this file per its docstring), this file could import from it instead.

Suggested consolidation
+from tests.test_utils.orthogonality import (
+    OrthogonalityMetrics,
+    OrthogonalityTracker,
+    check_gradient_orthogonality,
+    check_parameter_orthogonality,
+    compute_angle_differences,
+)
-
-@dataclass
-class OrthogonalityMetrics:
-    ...
-
-class OrthogonalityTracker:
-    ...
-
-def compute_angle_differences(...):
-    ...
-
-def check_gradient_orthogonality(...):
-    ...
-
-def check_parameter_orthogonality(...):
-    ...

Note: The regression test's OrthogonalityTracker has additional print statements in update() for real-time violation logging. You could either:

  1. Add that logging to the shared implementation, or
  2. Subclass OrthogonalityTracker in the regression test to add the print behavior.

Comment on lines +160 to +167
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101", "F841"] # allow assert and unused vars in tests
"__init__.py" = ["F401"] # allow unused imports in __init__
"regression_tests/*" = ["F401", "F811", "F841"] # allow unused and redefinitions in benchmark scripts
"research_scratch/*" = ["F401", "F841"] # allow unused in research scripts
"tutorials/*" = ["F401", "F841"] # allow unused in tutorials
"scripts/*" = ["E722"] # allow bare except in scripts
"test-*.py" = ["F841"] # allow unused vars in standalone test files

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check for nested Python files in target directories
for d in tests regression_tests research_scratch tutorials scripts; do
  if [ -d "$d" ]; then
    echo "== $d =="
    fd -t f '\.py$' "$d" | head -20
  fi
done

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1202


🌐 Web query:

Ruff per-file-ignores glob pattern matching semantics "tests/*" vs "tests/**"

💡 Result:

Ruff’s per-file-ignores patterns use Rust’s globset syntax, matched against paths relative to the project root. (docs.astral.sh)

tests/*

  • * matches within a single path component (it won’t cross /). (docs.rs)
  • So tests/* matches files directly inside tests/, e.g.:
    • tests/test_api.py
    • tests/unit/test_api.py ❌ (because of the extra /unit/)

tests/**

  • In globset, a pattern ending in /** matches all sub-entries under that directory, recursively (e.g., foo/** matches foo/a and foo/a/b, but not foo itself). (docs.rs)
  • So tests/** matches:
    • tests/test_api.py
    • tests/unit/test_api.py
    • but not the directory path tests itself (Ruff is matching files anyway). (docs.rs)


Update per-file-ignores to use recursive globs for nested test files.

Nested test files under tests/gpu_tests/ are not covered by the current tests/* pattern, which only matches one directory level. Ruff's globset syntax requires tests/** for recursive matching. Update all per-file-ignore patterns to use recursive globs:

Required fix
[tool.ruff.lint.per-file-ignores]
-"tests/*" = ["S101", "F841"]  # allow assert and unused vars in tests
-"regression_tests/*" = ["F401", "F811", "F841"]  # allow unused and redefinitions in benchmark scripts
-"research_scratch/*" = ["F401", "F841"]  # allow unused in research scripts
-"tutorials/*" = ["F401", "F841"]  # allow unused in tutorials
-"scripts/*" = ["E722"]  # allow bare except in scripts
+"tests/**" = ["S101", "F841"]  # allow assert and unused vars in tests
+"regression_tests/**" = ["F401", "F811", "F841"]  # allow unused and redefinitions in benchmark scripts
+"research_scratch/**" = ["F401", "F841"]  # allow unused in research scripts
+"tutorials/**" = ["F401", "F841"]  # allow unused in tutorials
+"scripts/**" = ["E722"]  # allow bare except in scripts

Comment on lines 418 to 429

 def test_inference():
     """Simple test to ensure GPT-OSS utilities can be imported."""
     try:
-        from mini_trainer.gpt_oss_utils import convert_dequantized_to_quantized_format_correct
-        from mini_trainer.gpt_oss_utils import update_config_for_quantized_format
+        from mini_trainer.gpt_oss_utils import (
+            convert_dequantized_to_quantized_format_correct,
+            update_config_for_quantized_format,
+        )

         assert callable(convert_dequantized_to_quantized_format_correct)
         assert callable(update_config_for_quantized_format)
     except ImportError as e:
         pytest.fail(f"Failed to import GPT-OSS utilities: {e}")

⚠️ Potential issue | 🔴 Critical

Function name collision will cause TypeError at runtime.

There are two functions named test_inference:

  1. Lines 383-416: test_inference(model, tokenizer, stage_name) - the actual inference test helper
  2. Lines 418-429: test_inference() - an import validation test with no parameters

The second definition shadows the first. When test_conversion_accuracy() calls test_inference(dequant_model, tokenizer, "DEQUANTIZED") at line 137 or test_inference(converted_model, tokenizer, "MINI_TRAINER_CONVERTED") at line 362, Python will use the parameterless version and raise:

TypeError: test_inference() takes 0 positional arguments but 3 were given
🐛 Proposed fix: rename the import validation test
-def test_inference():
-    """Simple test to ensure GPT-OSS utilities can be imported."""
+def test_gpt_oss_imports():
+    """Simple test to ensure GPT-OSS utilities can be imported."""
     try:
         from mini_trainer.gpt_oss_utils import (
             convert_dequantized_to_quantized_format_correct,
             update_config_for_quantized_format,
         )

         assert callable(convert_dequantized_to_quantized_format_correct)
         assert callable(update_config_for_quantized_format)
     except ImportError as e:
         pytest.fail(f"Failed to import GPT-OSS utilities: {e}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def test_inference():
+def test_gpt_oss_imports():
     """Simple test to ensure GPT-OSS utilities can be imported."""
     try:
         from mini_trainer.gpt_oss_utils import (
             convert_dequantized_to_quantized_format_correct,
             update_config_for_quantized_format,
         )

         assert callable(convert_dequantized_to_quantized_format_correct)
         assert callable(update_config_for_quantized_format)
     except ImportError as e:
         pytest.fail(f"Failed to import GPT-OSS utilities: {e}")

Comment on line 50

     def __init__(self, margin_deg: float = 1.0):
         self.margin_deg = margin_deg
-        self.metrics: Dict[str, Dict[str, OrthogonalityMetrics]] = {}
+        self.metrics: dict[str, dict[str, OrthogonalityMetrics]] = {}

⚠️ Potential issue | 🟡 Minor

Same type annotation mismatch as in tests/test_utils/orthogonality.py.

The stored values are plain dicts, not OrthogonalityMetrics instances. See the comment on tests/test_utils/orthogonality.py line 32 for suggested fixes.


Comment on line 32

     def __init__(self, margin_deg: float = 1.0):
         self.margin_deg = margin_deg
-        self.metrics: Dict[str, Dict[str, OrthogonalityMetrics]] = {}
+        self.metrics: dict[str, dict[str, OrthogonalityMetrics]] = {}

⚠️ Potential issue | 🟡 Minor

Incorrect type annotation: stored values are plain dicts, not OrthogonalityMetrics.

The type hint says dict[str, dict[str, OrthogonalityMetrics]], but update() (lines 46-51) stores plain dicts with string keys, not OrthogonalityMetrics instances. This mismatch will cause type checker errors.

Suggested fix

Either fix the type annotation to match actual usage:

-        self.metrics: dict[str, dict[str, OrthogonalityMetrics]] = {}
+        self.metrics: dict[str, dict[str, str | float | int]] = {}

Or better, actually use the dataclass:

-        self.metrics: dict[str, dict[str, OrthogonalityMetrics]] = {}
+        self.metrics: dict[str, OrthogonalityMetrics] = {}

Then in update():

         if key not in self.metrics:
-            self.metrics[key] = {
-                "param_name": param_name,
-                "check_type": check_type,
-                "max_angle_diff": max_angle_diff,
-                "step": step,
-            }
+            self.metrics[key] = OrthogonalityMetrics(
+                param_name=param_name,
+                check_type=check_type,
+                max_angle_diff=max_angle_diff,
+                step=step,
+            )
         else:
             # Update if this is worse
-            if max_angle_diff > self.metrics[key]["max_angle_diff"]:
-                self.metrics[key]["max_angle_diff"] = max_angle_diff
-                self.metrics[key]["step"] = step
+            if max_angle_diff > self.metrics[key].max_angle_diff:
+                self.metrics[key].max_angle_diff = max_angle_diff
+                self.metrics[key].step = step
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        self.metrics: dict[str, dict[str, OrthogonalityMetrics]] = {}
+        self.metrics: dict[str, dict[str, str | float | int]] = {}
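A minimal, self-contained sketch of the "actually use the dataclass" option: update() stores and mutates OrthogonalityMetrics instances, so the annotation and the stored data agree. The field names follow the review snippets; the rest (key format, update policy) is illustrative:

```python
from dataclasses import dataclass

@dataclass
class OrthogonalityMetrics:
    param_name: str
    check_type: str
    max_angle_diff: float
    step: int

class OrthogonalityTracker:
    def __init__(self, margin_deg: float = 1.0):
        self.margin_deg = margin_deg
        # Annotation now matches what update() actually stores.
        self.metrics: dict[str, OrthogonalityMetrics] = {}

    def update(self, param_name: str, check_type: str, max_angle_diff: float, step: int) -> None:
        key = f"{param_name}/{check_type}"
        current = self.metrics.get(key)
        if current is None:
            self.metrics[key] = OrthogonalityMetrics(param_name, check_type, max_angle_diff, step)
        elif max_angle_diff > current.max_angle_diff:
            # Keep only the worst violation and the step it occurred at.
            current.max_angle_diff = max_angle_diff
            current.step = step

tracker = OrthogonalityTracker()
tracker.update("layer0.weight", "grad", 0.3, step=1)
tracker.update("layer0.weight", "grad", 0.9, step=5)
assert tracker.metrics["layer0.weight/grad"].max_angle_diff == 0.9
assert tracker.metrics["layer0.weight/grad"].step == 5
```

With the dataclass stored directly, type checkers verify field access (`.max_angle_diff`) instead of silently passing string-keyed dict lookups.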

@RobotSail RobotSail merged commit b5115a0 into Red-Hat-AI-Innovation-Team:main Feb 5, 2026
10 of 11 checks passed