
Fix concat lowering for unit dims #424

Merged
lukstafi merged 1 commit into master from codex-ambiguous-concat
Jan 12, 2026

@lukstafi
Collaborator

Ensure concat components with d=1 keep iterators during projection inference, and gate concat block lowering by the active concat symbol so only one RHS/target is viable per iteration. This prevents the "Ambiguous indices in concatenation" error for 3-way concat with a unit-sized component while keeping two-way behavior intact.

Co-authored-by: Codex <codex@openai.com>
@claude
Contributor

claude bot commented Jan 12, 2026

Claude finished @lukstafi's task


PR Review: Fix concat lowering for unit dims

I analyzed this PR, which fixes the "Ambiguous indices in concatenation" error for 3-way concatenation with a unit-sized component. Feedback follows.

Todo List:

  • Read project guidelines from CLAUDE.md
  • Analyze the core implementation changes in arrayjit/lib/assignments.ml
  • Review tensor/row.ml and tensor/shape.ml changes
  • Examine test changes and coverage
  • Assess code quality and identify potential issues
  • Check for performance considerations
  • Provide final review feedback

Code Quality and Design

Overall Assessment: Good

The fix addresses a real bug where 3-way concatenation with unit dimensions failed during backpropagation. The approach is sound:

  1. Root Cause: In assignments.ml:557, the filtering for Rev_sides (gradient computation) wasn't properly excluding inactive RHS targets, causing multiple targets to be valid for the same position.

  2. Solution Strategy: Gate concatenation block lowering by active concat symbols, ensuring only one RHS/target is viable per iteration.
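
The gating idea can be illustrated with a toy model. This is only a sketch under stated assumptions: the `component` record, the integer symbols, and the functions `iterations`, `viable`, and `viable_components` are hypothetical and are not taken from assignments.ml. It shows that when each block is guarded by the currently active concat symbol, exactly one component is viable per iteration, even when one component has size 1.

```ocaml
(* Toy model, not OCANNL's real types: each concat component is identified by
   a hypothetical integer symbol and carries its own size. *)
type component = { sym : int; size : int }

(* Enumerate (active symbol, local index) pairs in output order. Exactly one
   symbol is active for each position of the concatenated axis. *)
let iterations (components : component list) : (int * int) list =
  List.concat_map
    (fun c -> List.init c.size (fun i -> (c.sym, i)))
    components

(* Gate: the block for [block_sym] is lowered only while that symbol is
   the active one. *)
let viable ~active_sym ~block_sym = active_sym = block_sym

(* For one iteration, list the components whose block passes the gate. *)
let viable_components (components : component list) (active_sym, _local) =
  List.filter (fun c -> viable ~active_sym ~block_sym:c.sym) components
```

With components of sizes 2, 1, and 3, every one of the six iterations selects a single component, so no position has two candidate targets.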

Specific Code Analysis

arrayjit/lib/assignments.ml (assignments.ml:306-372 and assignments.ml:488-557)

Quality: Good

let concat_syms_opt = 
  match Array.filter_map projections.project_lhs ~f:(function
    | Indexing.Concat syms -> Some syms | _ -> None) with
  | [| syms |] when List.length syms = Array.length rhses -> Some (Array.of_list syms)
  | _ -> None

Strengths:

  • Clean pattern matching to extract concat symbols
  • Conservative approach: only applies optimization when there's exactly one Concat with matching symbol count
  • Symmetric implementation for both Block and Rev_sides cases

Concerns:

  • The logic is duplicated between two functions (lines ~315 and ~495). Consider extracting a helper function:
let extract_concat_symbols projections rhses_length =
  match Array.filter_map projections.project_lhs ~f:(function
    | Indexing.Concat syms -> Some syms | _ -> None) with
  | [| syms |] when List.length syms = rhses_length -> Some (Array.of_list syms)
  | _ -> None
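
For readers without the repository at hand, the same helper can be tried in isolation. The following is a minimal sketch using only the standard library; the `axis_index` type is a simplified stand-in with integer symbols (an assumption, not the real `Indexing` definitions), and the helper takes the LHS projection array directly instead of a projections record.

```ocaml
(* Simplified stand-in for the axis index type: only the constructors needed
   for this illustration are modelled, and symbols are plain ints. *)
type axis_index =
  | Fixed_idx of int
  | Iterator of int
  | Concat of int list

(* Return the concat symbols only when exactly one LHS axis is a Concat and it
   has one symbol per RHS; otherwise fall back to None (no gating). *)
let extract_concat_symbols (project_lhs : axis_index array) (rhses_count : int)
    : int array option =
  let concats =
    Array.to_list project_lhs
    |> List.filter_map (function Concat syms -> Some syms | _ -> None)
  in
  match concats with
  | [ syms ] when List.length syms = rhses_count -> Some (Array.of_list syms)
  | _ -> None
```

The conservative `None` fallback covers three cases: no `Concat` axis, more than one `Concat` axis, and a symbol count that does not match the number of RHS operands.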

tensor/row.ml (row.ml:4105-4280)

Quality: Good with minor concerns ⚠️

The key changes handle unit dimensions (d=1) specially in concat components:

if Idx.equal_axis_index idx (Idx.Fixed_idx 0) && Set.mem concat_reprs repr then ()

Strengths:

  • Correctly identifies that unit dimensions in concat components should not raise conflicts
  • Ensures concat components get iterators even when d=1

Concerns:

  • Magic number: Fixed_idx 0 appears to be special-cased but lacks documentation explaining why index 0 specifically is exempt from conflict checking for concat representatives
  • Complexity: The logic around concat_reprs and concat_dims adds complexity but appears necessary
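
One plausible intuition for the unit-dimension case, offered as a sketch and not as a description of row.ml: a component with d=1 only ever has local index 0, which looks like a fixed index, yet it still occupies its own slot along the concatenated axis. The hypothetical `locate` function below maps an output position to its component and local index for arbitrary component sizes.

```ocaml
(* Map a position along the concatenated axis to
   (component number, index within that component).
   Returns None when the position is out of range. *)
let locate (sizes : int list) (pos : int) : (int * int) option =
  let rec go k offset = function
    | [] -> None
    | d :: rest ->
        if pos >= offset && pos < offset + d then Some (k, pos - offset)
        else go (k + 1) (offset + d) rest
  in
  if pos < 0 then None else go 0 0 sizes
```

For sizes 2, 1, and 3, position 2 belongs to the unit component at local index 0, and positions 3 through 5 belong to the third component. If the unit component were treated as having no iterator, nothing would distinguish its slot from its neighbours.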

tensor/shape.ml (shape.ml:1877-1930)

Quality: Good

let symbol_to_proj = Map.of_alist_exn (module Idx.Symbol)
  (Row.product_dim_iterators proj_env |> List.map ~f:(fun (p, d, s) -> (s, (p, d))))

Strengths:

  • Elegant mapping from symbols to projections
  • Ensures concat component iterators are preserved even with dim=1
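
The shape of that mapping can be sketched with the standard library alone. The integer projection ids, dimensions, and symbols below are assumptions for illustration, and `symbol_to_proj` here is a hypothetical stand-in for the snippet above. Like Base's `Map.of_alist_exn`, it raises on a duplicate key, so each symbol must appear once in the iterator list.

```ocaml
module IntMap = Map.Make (Int)

(* Build a map from symbol to (projection id, dimension) out of
   (projection id, dimension, symbol) triples. A repeated symbol is an error,
   mirroring the raising behaviour of an of_alist_exn style constructor. *)
let symbol_to_proj (iterators : (int * int * int) list) : (int * int) IntMap.t =
  List.fold_left
    (fun acc (p, d, s) ->
      if IntMap.mem s acc then invalid_arg "symbol_to_proj: duplicate symbol"
      else IntMap.add s (p, d) acc)
    IntMap.empty iterators
```

An entry with dimension 1 is kept like any other, which is the property the review highlights for unit-sized concat components.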

Test Coverage

Quality: Excellent

The test coverage is comprehensive:

  1. 2-way concatenation: Baseline case that was already working
  2. 3-way concatenation: Tests the main fix (different sized components)
  3. 3-way with unit dimension: Tests the specific edge case (one component has dim=1)

The test structure is well-organized and provides clear output showing forward pass, backward pass, and gradients.
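
The property these tests exercise can be stated on plain arrays. The sketch below is not the project's test code; `concat_forward` and `concat_backward` are hypothetical names. It models the forward pass as array concatenation and the backward pass as slicing the output gradient back into one piece per component.

```ocaml
(* Forward pass: concatenate the components along their single axis. *)
let concat_forward (parts : float array list) : float array =
  Array.concat parts

(* Backward pass: route the gradient of the output to each component by
   slicing at the running offset. Array.sub raises Invalid_argument if the
   sizes do not fit inside [grad]. *)
let concat_backward (sizes : int list) (grad : float array) : float array list =
  let _, rev_slices =
    List.fold_left
      (fun (offset, acc) d -> (offset + d, Array.sub grad offset d :: acc))
      (0, []) sizes
  in
  List.rev rev_slices
```

Splitting the concatenated array with the original sizes returns the original components, including the unit-sized one, which is the round trip a 3-way test with a d=1 component should preserve.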

Potential Issues

Security Concerns

None identified ✅ - This is pure computational logic with no external dependencies or unsafe operations.

Performance Considerations

Minor Impact ⚠️

  • Additional set operations and map lookups in hot paths (solve_proj_equations)
  • The concat_reprs set creation adds overhead but only when concat operations are present
  • The filtering logic in assignments.ml adds minor computational cost per iteration

Recommendation: Consider profiling concat-heavy workloads to ensure no significant regression.

Edge Cases and Robustness

Good with caveats ⚠️

  1. Pattern matching exhaustiveness: The code assumes specific projection structures. Consider what happens if:

    • Multiple Concat projections exist simultaneously
    • Concat symbols don't match expected array lengths
    • Mixed concatenation patterns occur
  2. Error handling: The code silently falls back to None for concat_syms_opt. This is probably correct, but could benefit from debug logging.

Suggestions for Improvement

  1. Extract duplicate logic (assignments.ml:315 & 495):

    let extract_concat_symbols projections rhses_count =
      match Array.filter_map projections.project_lhs ~f:(function
        | Indexing.Concat syms -> Some syms | _ -> None) with
      | [| syms |] when List.length syms = rhses_count -> Some (Array.of_list syms)
      | _ -> None
  2. Add debug logging for fallback cases:

    | _ -> 
      [%debug4_sexp "No concat optimization applied", (projections, Array.length rhses)];
      None
  3. Document the Fixed_idx 0 exception in row.ml - why is index 0 special for concat representatives?

Final Verdict

Approve with minor suggestions

This is a solid fix that:

  • ✅ Addresses a real bug affecting 3-way concatenation
  • ✅ Maintains backward compatibility
  • ✅ Has comprehensive test coverage
  • ✅ Follows project conventions
  • ✅ Implements a clean, targeted solution

The code quality is good overall. The main suggestions are code deduplication and better documentation of the special case handling.


lukstafi merged commit 9517730 into master on Jan 12, 2026
1 of 5 checks passed
lukstafi deleted the codex-ambiguous-concat branch on January 12, 2026 18:06