diff --git a/arrayjit/lib/assignments.ml b/arrayjit/lib/assignments.ml index 86f20976..cf4d1857 100644 --- a/arrayjit/lib/assignments.ml +++ b/arrayjit/lib/assignments.ml @@ -282,12 +282,28 @@ let%track4_sexp to_low_level code = |> List.concat |> Set.of_list (module Indexing.Symbol) in - let basecase block_iters rev_iters concat_offsets = + let iter_sizes = + Array.fold2_exn projections.product_space projections.product_iterators + ~init:(Map.empty (module Indexing.Symbol)) + ~f:(fun acc ds its -> + List.fold2_exn ds its ~init:acc ~f:(fun acc d iter -> + Map.set acc ~key:iter ~data:d)) + in + let concat_offset_for syms active = + let _, offset = + List.fold syms ~init:(0, None) ~f:(fun (cumul, found) s -> + let size = Map.find iter_sizes s |> Option.value ~default:0 in + if Indexing.equal_symbol s active then (cumul + size, Some cumul) + else (cumul + size, found)) + in + Option.value ~default:0 offset + in + let basecase block_iters rev_iters = (* Create a substitution from product iterators to loop iterators. Fresh loop symbols are needed because product_iterators may be shared across different operations/tensors, but each lowered operation needs private loop symbols to avoid conflicts in low_level.ml's symbol-to-tensor tracking. - concat_offsets: maps iterator symbol to its cumulative offset within a Concat group. *) + Concat offsets are computed per Concat index using symbol order. *) let exception Empty_block in let block_iters = Array.of_list_rev block_iters in let subst_map = @@ -319,13 +335,13 @@ let%track4_sexp to_low_level code = Indexing.Affine { symbols; offset } | Indexing.Concat syms -> (* For Block lowering: find the active component (in block_iters) and resolve to it - with the appropriate offset from concat_offsets. *) + with the appropriate offset based on Concat symbol order. *) let active = List.find_mapi syms ~f:(fun _i s -> if Array.mem ~equal:Indexing.equal_symbol block_iters s then match Map.find subst_map s with | Some (Indexing.Iterator s') -> - let offset = Map.find concat_offsets s |> Option.value ~default:0 in + let offset = concat_offset_for syms s in Some (s', offset) | _ -> None else None) @@ -334,14 +350,9 @@ let%track4_sexp to_low_level code = | Some (s', 0) -> Indexing.Iterator s' | Some (s', offset) -> Indexing.Affine { symbols = [ (1, s') ]; offset } | None -> - (* No active component - this shouldn't happen in Block lowering *) - let syms' = - List.map syms ~f:(fun s -> - match Map.find subst_map s with - | Some (Indexing.Iterator s') -> s' - | _ -> s) - in - Indexing.Concat syms') + raise + @@ Utils.User_error + "Concat index could not be resolved to an active component during Block lowering") in try let lhs_idcs : Indexing.axis_index array = @@ -368,19 +379,10 @@ let%track4_sexp to_low_level code = else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |] with Empty_block -> Low_level.Noop in - let rec for_loop block_iters rev_iters concat_offsets = function - | [] -> basecase block_iters rev_iters concat_offsets + let rec for_loop block_iters rev_iters = function + | [] -> basecase block_iters rev_iters | (ds, its) :: product -> let index = Indexing.get_symbol () in - (* Build cumulative offsets for this group (for Concat resolution) *) - let _, offsets_for_group = - List.fold2_exn ds its ~init:(0, []) ~f:(fun (cumul, acc) d iter -> - (cumul + d, (iter, cumul) :: acc)) - in - let concat_offsets' = - List.fold offsets_for_group ~init:concat_offsets ~f:(fun m (iter, offset) -> - Map.set m ~key:iter ~data:offset) - in Low_level.unflat_lines @@ List.map2_exn ds its ~f:(fun d iter -> Low_level.For_loop @@ -388,13 +390,12 @@ let%track4_sexp to_low_level code = index; from_ = 0; to_ = d - 1; - body = - for_loop (iter :: block_iters) (index :: rev_iters) concat_offsets' product; + body = for_loop (iter :: block_iters) (index :: rev_iters) product; trace_it = true; }) in let for_loops = - for_loop [] [] (Map.empty (module Indexing.Symbol)) + for_loop [] [] (Array.to_list @@ Array.zip_exn projections.product_space projections.product_iterators) in (* Need initialization if: initialize_neutral is true AND (not surjective OR not injective) @@ -452,7 +453,23 @@ let%track4_sexp to_low_level code = initialize_neutral && not (Indexing.is_surjective proj && Indexing.is_injective proj)) in - let basecase block_iters rev_iters concat_offsets = + let iter_sizes = + Array.fold2_exn projections.product_space projections.product_iterators + ~init:(Map.empty (module Indexing.Symbol)) + ~f:(fun acc ds its -> + List.fold2_exn ds its ~init:acc ~f:(fun acc d iter -> + Map.set acc ~key:iter ~data:d)) + in + let concat_offset_for syms active = + let _, offset = + List.fold syms ~init:(0, None) ~f:(fun (cumul, found) s -> + let size = Map.find iter_sizes s |> Option.value ~default:0 in + if Indexing.equal_symbol s active then (cumul + size, Some cumul) + else (cumul + size, found)) + in + Option.value ~default:0 offset + in + let basecase block_iters rev_iters = let exception Empty_block in let block_iters = Array.of_list_rev block_iters in let subst_map = @@ -488,7 +505,7 @@ let%track4_sexp to_low_level code = if Array.mem ~equal:Indexing.equal_symbol block_iters s then match Map.find subst_map s with | Some (Indexing.Iterator s') -> - let offset = Map.find concat_offsets s |> Option.value ~default:0 in + let offset = concat_offset_for syms s in Some (s', offset) | _ -> None else None) @@ -497,13 +514,9 @@ let%track4_sexp to_low_level code = | Some (s', 0) -> Indexing.Iterator s' | Some (s', offset) -> Indexing.Affine { symbols = [ (1, s') ]; offset } | None -> - let syms' = - List.map syms ~f:(fun s -> - match Map.find subst_map s with - | Some (Indexing.Iterator s') -> s' - | _ -> s) - in - Indexing.Concat syms') + raise + @@ Utils.User_error + "Concat index could not be resolved to an active component during Rev_sides lowering") in let target_tn_exn = function | Node tn -> tn @@ -535,19 +548,10 @@ let%track4_sexp to_low_level code = else set target_tn lhs_idcs @@ apply_op (Ops.Binop accum) [| get target_buf lhs_idcs; rhs2 |] with Empty_block -> Low_level.Noop in - let rec for_loop block_iters rev_iters concat_offsets = function - | [] -> basecase block_iters rev_iters concat_offsets + let rec for_loop block_iters rev_iters = function + | [] -> basecase block_iters rev_iters | (ds, its) :: product -> let index = Indexing.get_symbol () in - (* Build cumulative offsets for this group (for Concat resolution) *) - let _, offsets_for_group = - List.fold2_exn ds its ~init:(0, []) ~f:(fun (cumul, acc) d iter -> - (cumul + d, (iter, cumul) :: acc)) - in - let concat_offsets' = - List.fold offsets_for_group ~init:concat_offsets ~f:(fun m (iter, offset) -> - Map.set m ~key:iter ~data:offset) - in Low_level.unflat_lines @@ List.map2_exn ds its ~f:(fun d iter -> Low_level.For_loop @@ -555,13 +559,12 @@ let%track4_sexp to_low_level code = index; from_ = 0; to_ = d - 1; - body = - for_loop (iter :: block_iters) (index :: rev_iters) concat_offsets' product; + body = for_loop (iter :: block_iters) (index :: rev_iters) product; trace_it = true; }) in let for_loops = - for_loop [] [] (Map.empty (module Indexing.Symbol)) + for_loop [] [] (Array.to_list @@ Array.zip_exn projections.product_space projections.product_iterators) in let neutral_value = Ops.neutral_elem accum in diff --git a/tensor/row.ml b/tensor/row.ml index 39de749a..cea1c7f5 100644 --- a/tensor/row.ml +++ b/tensor/row.ml @@ -4207,26 +4207,41 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) let target_repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:target_pid ~rank:0 in - if not (Map.mem !projs target_repr) then ( - let syms = - List.filter_map proj_dims ~f:(fun (pid, { d; _ }) -> - let repr, _ = - Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:pid ~rank:0 - in - match Map.find !projs repr with - | Some (Idx.Iterator s) -> Some s - | Some (Idx.Fixed_idx 0) when d = 0 -> None (* d=0 is invalid dimension, skip *) - | _ when d = 0 -> None - | _ -> - raise - @@ Shape_error - ( [%string - "Concat component projection %{pid#Proj_id} (d=%{d#Int}) has no iterator"], - [] )) - in - projs := - Map.set !projs ~key:target_repr - ~data:(if List.is_empty syms then Idx.Fixed_idx 0 else Idx.Concat syms))); + let syms = + List.filter_map proj_dims ~f:(fun (pid, { d; _ }) -> + let repr, _ = + Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:pid ~rank:0 + in + match Map.find !projs repr with + | Some (Idx.Iterator s) -> Some s + | Some (Idx.Fixed_idx 0) when d = 0 -> None (* d=0 is invalid dimension, skip *) + | _ when d = 0 -> None + | _ -> + raise + @@ Shape_error + ( [%string + "Concat component projection %{pid#Proj_id} (d=%{d#Int}) has no iterator"], + [] )) + in + let expected_idx = + if List.is_empty syms then Idx.Fixed_idx 0 else Idx.Concat syms + in + match Map.find !projs target_repr with + | None -> projs := Map.set !projs ~key:target_repr ~data:expected_idx + | Some existing_idx -> + let ok = + Idx.equal_axis_index existing_idx expected_idx + || + match (existing_idx, expected_idx) with + | Idx.Iterator s, Idx.Concat [ s' ] when Idx.equal_symbol s s' -> true + | _ -> false + in + if not ok then + raise + @@ Shape_error + ( [%string + "Concat target projection %{target_pid#Proj_id} conflicts with existing index"], + [ Index_mismatch [ existing_idx; expected_idx ] ] )); { v_env; diff --git a/tensor/shape.ml b/tensor/shape.ml index a3d544e1..b63a979a 100644 --- a/tensor/shape.ml +++ b/tensor/shape.ml @@ -1887,13 +1887,20 @@ let%debug4_sexp derive_projections (update_step : update_step) : unit = List.filter_map all_dims ~f:(fun dim -> match dim with | Row.Concat _ -> Some (Row.get_dim_index proj_env dim) + | Dim { proj_id = Some _; _ } -> ( + match Row.get_product_proj proj_env dim with + | Some _ -> Some (Row.get_dim_index proj_env dim) + | None -> ( + (* Also check if dim's projection maps to Idx.Concat *) + try + match Row.get_dim_index proj_env dim with + | Idx.Concat _ as idx -> Some idx + | _ -> None + with _ -> None)) | _ -> ( match Row.get_product_proj proj_env dim with | Some _ -> Some (Row.get_dim_index proj_env dim) - | None -> - (* Also check if dim has a proj_id that maps to Idx.Concat *) - let idx = Row.get_dim_index proj_env dim in - (match idx with Idx.Concat _ -> Some idx | _ -> None))) + | None -> None)) in let concat_groups : Idx.symbol list list = List.filter_map product_indices ~f:(function Idx.Concat syms -> Some syms | _ -> None)