
Conversation

@lhparker1 (Member)

This pull request includes a minor fix to the subsampler.py file. The change ensures that the squeeze method explicitly removes the last dimension (dim=-1) when calculating the scales tensor.
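A minimal sketch of what the change amounts to (the values and variable setup here are hypothetical; only the `scales` expression itself appears in the review below). The difference matters precisely when the batch size is 1:

```python
import torch

dim_in = 512
label_sizes = torch.tensor([[4.0]])  # shape [1, 1]: batch size 1, kept by keepdim=True

# Before: generic squeeze() removes *every* size-1 dimension, including the batch dim
scales_old = ((dim_in / label_sizes) ** 0.5).squeeze()    # shape [] — batch dim lost

# After: squeeze(-1) removes only the trailing dimension
scales_new = ((dim_in / label_sizes) ** 0.5).squeeze(-1)  # shape [1] — batch dim kept
```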

@lhparker1 lhparker1 requested a review from EiffL June 17, 2025 18:51
@EiffL (Contributor) commented Jun 17, 2025

Hmm, but is that the problem? @NolanKoblischke's issue was that things didn't work if he had a batch size of 1, or no batch dimension at all.

@EiffL (Contributor) commented Jun 17, 2025

Can we instead add a piece of code in the codecs' encode functions to ensure a batch dimension is there?
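One way this suggestion could look (a hypothetical helper; the actual codec encode signatures in the project may differ) is to unsqueeze a leading batch dimension whenever the input arrives unbatched:

```python
import torch

def ensure_batch_dim(x: torch.Tensor, expected_ndim: int) -> torch.Tensor:
    """Prepend a batch dimension if the input arrived unbatched.

    Hypothetical sketch of the suggestion above, not the project's actual API.
    """
    if x.ndim == expected_ndim - 1:
        x = x.unsqueeze(0)
    return x

# An unbatched [C, H, W] image becomes [1, C, H, W]; batched input passes through.
img = torch.randn(3, 64, 64)
batched = ensure_batch_dim(img, expected_ndim=4)  # shape [1, 3, 64, 64]
```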

@EiffL (Contributor) left a review comment

Code Review: Massive update to subsampler

Overview

This PR makes a targeted fix to the subsampler.py file, changing a generic squeeze() call to explicitly specify squeeze(-1) when calculating the scales tensor. Despite the "massive update" title, this is a minimal but important change.

Analysis

Code Quality & Correctness ✅

  • Good fix: The change from squeeze() to squeeze(-1) is a sound improvement
  • Explicit dimension handling: Specifying the dimension to squeeze makes the code more predictable and safer
  • Maintains functionality: The mathematical operation remains the same but with better dimension control

Following Project Conventions ✅

  • Consistent with PyTorch best practices: Explicit dimension specification is recommended
  • Aligns with codebase style: The change follows the project's tensor operation patterns
  • Type safety: Helps maintain tensor shape consistency with the jaxtyping annotations used throughout the project

Technical Implications

Benefits:

  • Prevents silent errors: Generic squeeze() could remove unintended dimensions if tensor shapes change
  • Better debugging: Explicit dimension specification makes errors more traceable
  • Shape consistency: Ensures the scales tensor always has the expected shape for broadcasting

Potential considerations:

  • Runtime behavior: In PyTorch, squeeze(-1) is a no-op when the last dimension is not size 1, whereas a generic squeeze() would still remove any other size-1 dimensions (including a batch dimension of 1)
  • Backward compatibility: Code that relied on scales collapsing to a scalar for batch size 1 would now see shape [1] instead
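The behavioral difference between the two calls can be shown directly (shapes here are illustrative, not taken from subsampler.py):

```python
import torch

x = torch.randn(1, 5, 1)  # batch of 1, with a trailing size-1 dimension

# Generic squeeze() removes every size-1 dimension, batch dim included
print(x.squeeze().shape)    # torch.Size([5])

# squeeze(-1) removes only the trailing dimension, preserving the batch dim
print(x.squeeze(-1).shape)  # torch.Size([1, 5])

# And when the last dimension is not size 1, squeeze(-1) silently does nothing
y = torch.randn(2, 3)
print(y.squeeze(-1).shape)  # torch.Size([2, 3]) — no error raised
```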

Context Analysis

Looking at the code:

scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1)

  • label_sizes has shape [b, 1] (from keepdim=True)
  • The division and power operations maintain this shape
  • squeeze(-1) removes the last dimension, resulting in shape [b]
  • This is then used for broadcasting: scales[:, None, None, None]
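The shape walkthrough above can be traced end to end (the construction of label_sizes is hypothetical; only its [b, 1] shape and the scales expression come from the review):

```python
import torch

b, dim_in = 4, 512
# Hypothetical stand-in for label_sizes: any reduction with keepdim=True gives [b, 1]
label_sizes = torch.full((b, 1), 2.0)

scales = ((dim_in / label_sizes) ** 0.5).squeeze(-1)  # [b, 1] -> [b]
assert scales.shape == torch.Size([b])

# scales[:, None, None, None] has shape [b, 1, 1, 1] and broadcasts
# cleanly against a [b, C, H, W] tensor
imgs = torch.randn(b, 3, 8, 8)
scaled = imgs * scales[:, None, None, None]
assert scaled.shape == imgs.shape
```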

Verdict

LGTM - This is good defensive programming that improves code reliability, and for a batch size of 1 it preserves the batch dimension that a generic squeeze() would have removed. The change is solid and should be merged.

Minor suggestion: Consider adding a comment explaining why the last dimension is squeezed, especially since this relates to the broadcasting pattern used later.

@EiffL EiffL merged commit 3941ddf into main Jun 17, 2025
3 checks passed