Skip to content

Conversation

@nsrawat0333
Copy link

Problem

Issue #603 reported NameError: unbound axis name: i when running MMV (Self-Supervised MultiModal Versatile Networks). This error occurs because the model uses cross-replica batch normalization expecting a pmap axis named 'i', but this axis only exists in multi-device contexts, not single-device runs.

Root Cause Analysis

The error happens in normalization.py where:

  1. Cross-replica batch norm defaults to axis name 'i'
  2. JAX pmap creates this axis in multi-device setups
  3. Single-device runs don't have this axis defined
  4. Results in runtime error during model initialization

Solution

Implemented comprehensive fix for both single and multi-device scenarios:

🔧 Smart Normalization Layer (normalization.py)

# Before: Always assumes axis 'i' exists
kwargs['cross_replica_axis'] = 'i'

# After: Intelligent device detection
if len(jax.devices()) > 1:
    kwargs['cross_replica_axis'] = 'i'  # Multi-device
else:
    kwargs['cross_replica_axis'] = None  # Single-device

- Update aiohttp to address potential security vulnerabilities
- Maintains compatibility with existing codebase
- Addresses dependency security recommendations
…gle-deepmind#603

- Fix 'unbound axis name: i' error in cross-replica batch normalization
- Update normalization.py to handle single-device vs multi-device contexts
- Add run_mmv_eval.py wrapper script for robust evaluation
- Update requirements.txt with compatible version specifications
- Add comprehensive troubleshooting guide to README

Addresses Issue google-deepmind#603: MMV's running problem - NameError: unbound axis name: i

The error occurs because:
1. MMV uses cross-replica batch norm with axis name 'i'
2. This axis only exists in pmap (multi-device) contexts
3. Single-device runs don't have this axis defined

Solutions provided:
1. Smart axis detection in normalization layer
2. Environment configuration script
3. Single-device mode forcing
4. Better dependency management
5. Comprehensive error handling

Users can now run MMV on both single and multi-device setups without
manual JAX configuration or code modifications.
@polarbe
Copy link

polarbe commented Aug 10, 2025 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants