Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 52 additions & 49 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,28 @@ let%track4_sexp to_low_level code =
|> List.concat
|> Set.of_list (module Indexing.Symbol)
in
let basecase block_iters rev_iters concat_offsets =
let iter_sizes =
Array.fold2_exn projections.product_space projections.product_iterators
~init:(Map.empty (module Indexing.Symbol))
~f:(fun acc ds its ->
List.fold2_exn ds its ~init:acc ~f:(fun acc d iter ->
Map.set acc ~key:iter ~data:d))
in
let concat_offset_for syms active =
let _, offset =
List.fold syms ~init:(0, None) ~f:(fun (cumul, found) s ->
let size = Map.find iter_sizes s |> Option.value ~default:0 in
if Indexing.equal_symbol s active then (cumul + size, Some cumul)
else (cumul + size, found))
in
Option.value ~default:0 offset
in
let basecase block_iters rev_iters =
(* Create a substitution from product iterators to loop iterators. Fresh loop symbols are
needed because product_iterators may be shared across different operations/tensors, but
each lowered operation needs private loop symbols to avoid conflicts in low_level.ml's
symbol-to-tensor tracking.
concat_offsets: maps iterator symbol to its cumulative offset within a Concat group. *)
Concat offsets are computed per Concat index using symbol order. *)
let exception Empty_block in
let block_iters = Array.of_list_rev block_iters in
let subst_map =
Expand Down Expand Up @@ -319,13 +335,13 @@ let%track4_sexp to_low_level code =
Indexing.Affine { symbols; offset }
| Indexing.Concat syms ->
(* For Block lowering: find the active component (in block_iters) and resolve to it
with the appropriate offset from concat_offsets. *)
with the appropriate offset based on Concat symbol order. *)
let active =
List.find_mapi syms ~f:(fun _i s ->
if Array.mem ~equal:Indexing.equal_symbol block_iters s then
match Map.find subst_map s with
| Some (Indexing.Iterator s') ->
let offset = Map.find concat_offsets s |> Option.value ~default:0 in
let offset = concat_offset_for syms s in
Some (s', offset)
| _ -> None
else None)
Expand All @@ -334,14 +350,9 @@ let%track4_sexp to_low_level code =
| Some (s', 0) -> Indexing.Iterator s'
| Some (s', offset) -> Indexing.Affine { symbols = [ (1, s') ]; offset }
| None ->
(* No active component - this shouldn't happen in Block lowering *)
let syms' =
List.map syms ~f:(fun s ->
match Map.find subst_map s with
| Some (Indexing.Iterator s') -> s'
| _ -> s)
in
Indexing.Concat syms')
raise
@@ Utils.User_error
"Concat index could not be resolved to an active component during Block lowering")
in
try
let lhs_idcs : Indexing.axis_index array =
Expand All @@ -368,33 +379,23 @@ let%track4_sexp to_low_level code =
else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |]
with Empty_block -> Low_level.Noop
in
let rec for_loop block_iters rev_iters concat_offsets = function
| [] -> basecase block_iters rev_iters concat_offsets
let rec for_loop block_iters rev_iters = function
| [] -> basecase block_iters rev_iters
| (ds, its) :: product ->
let index = Indexing.get_symbol () in
(* Build cumulative offsets for this group (for Concat resolution) *)
let _, offsets_for_group =
List.fold2_exn ds its ~init:(0, []) ~f:(fun (cumul, acc) d iter ->
(cumul + d, (iter, cumul) :: acc))
in
let concat_offsets' =
List.fold offsets_for_group ~init:concat_offsets ~f:(fun m (iter, offset) ->
Map.set m ~key:iter ~data:offset)
in
Low_level.unflat_lines
@@ List.map2_exn ds its ~f:(fun d iter ->
Low_level.For_loop
{
index;
from_ = 0;
to_ = d - 1;
body =
for_loop (iter :: block_iters) (index :: rev_iters) concat_offsets' product;
body = for_loop (iter :: block_iters) (index :: rev_iters) product;
trace_it = true;
})
in
let for_loops =
for_loop [] [] (Map.empty (module Indexing.Symbol))
for_loop [] []
(Array.to_list @@ Array.zip_exn projections.product_space projections.product_iterators)
in
(* Need initialization if: initialize_neutral is true AND (not surjective OR not injective)
Expand Down Expand Up @@ -452,7 +453,23 @@ let%track4_sexp to_low_level code =
initialize_neutral
&& not (Indexing.is_surjective proj && Indexing.is_injective proj))
in
let basecase block_iters rev_iters concat_offsets =
let iter_sizes =
Array.fold2_exn projections.product_space projections.product_iterators
~init:(Map.empty (module Indexing.Symbol))
~f:(fun acc ds its ->
List.fold2_exn ds its ~init:acc ~f:(fun acc d iter ->
Map.set acc ~key:iter ~data:d))
in
let concat_offset_for syms active =
let _, offset =
List.fold syms ~init:(0, None) ~f:(fun (cumul, found) s ->
let size = Map.find iter_sizes s |> Option.value ~default:0 in
if Indexing.equal_symbol s active then (cumul + size, Some cumul)
else (cumul + size, found))
in
Option.value ~default:0 offset
in
let basecase block_iters rev_iters =
let exception Empty_block in
let block_iters = Array.of_list_rev block_iters in
let subst_map =
Expand Down Expand Up @@ -488,7 +505,7 @@ let%track4_sexp to_low_level code =
if Array.mem ~equal:Indexing.equal_symbol block_iters s then
match Map.find subst_map s with
| Some (Indexing.Iterator s') ->
let offset = Map.find concat_offsets s |> Option.value ~default:0 in
let offset = concat_offset_for syms s in
Some (s', offset)
| _ -> None
else None)
Expand All @@ -497,13 +514,9 @@ let%track4_sexp to_low_level code =
| Some (s', 0) -> Indexing.Iterator s'
| Some (s', offset) -> Indexing.Affine { symbols = [ (1, s') ]; offset }
| None ->
let syms' =
List.map syms ~f:(fun s ->
match Map.find subst_map s with
| Some (Indexing.Iterator s') -> s'
| _ -> s)
in
Indexing.Concat syms')
raise
@@ Utils.User_error
"Concat index could not be resolved to an active component during Rev_sides lowering")
in
let target_tn_exn = function
| Node tn -> tn
Expand Down Expand Up @@ -535,33 +548,23 @@ let%track4_sexp to_low_level code =
else set target_tn lhs_idcs @@ apply_op (Ops.Binop accum) [| get target_buf lhs_idcs; rhs2 |]
with Empty_block -> Low_level.Noop
in
let rec for_loop block_iters rev_iters concat_offsets = function
| [] -> basecase block_iters rev_iters concat_offsets
let rec for_loop block_iters rev_iters = function
| [] -> basecase block_iters rev_iters
| (ds, its) :: product ->
let index = Indexing.get_symbol () in
(* Build cumulative offsets for this group (for Concat resolution) *)
let _, offsets_for_group =
List.fold2_exn ds its ~init:(0, []) ~f:(fun (cumul, acc) d iter ->
(cumul + d, (iter, cumul) :: acc))
in
let concat_offsets' =
List.fold offsets_for_group ~init:concat_offsets ~f:(fun m (iter, offset) ->
Map.set m ~key:iter ~data:offset)
in
Low_level.unflat_lines
@@ List.map2_exn ds its ~f:(fun d iter ->
Low_level.For_loop
{
index;
from_ = 0;
to_ = d - 1;
body =
for_loop (iter :: block_iters) (index :: rev_iters) concat_offsets' product;
body = for_loop (iter :: block_iters) (index :: rev_iters) product;
trace_it = true;
})
in
let for_loops =
for_loop [] [] (Map.empty (module Indexing.Symbol))
for_loop [] []
(Array.to_list @@ Array.zip_exn projections.product_space projections.product_iterators)
in
let neutral_value = Ops.neutral_elem accum in
Expand Down
55 changes: 35 additions & 20 deletions tensor/row.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4207,26 +4207,41 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
let target_repr, _ =
Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:target_pid ~rank:0
in
if not (Map.mem !projs target_repr) then (
let syms =
List.filter_map proj_dims ~f:(fun (pid, { d; _ }) ->
let repr, _ =
Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:pid ~rank:0
in
match Map.find !projs repr with
| Some (Idx.Iterator s) -> Some s
| Some (Idx.Fixed_idx 0) when d = 0 -> None (* d=0 is invalid dimension, skip *)
| _ when d = 0 -> None
| _ ->
raise
@@ Shape_error
( [%string
"Concat component projection %{pid#Proj_id} (d=%{d#Int}) has no iterator"],
[] ))
in
projs :=
Map.set !projs ~key:target_repr
~data:(if List.is_empty syms then Idx.Fixed_idx 0 else Idx.Concat syms)));
let syms =
List.filter_map proj_dims ~f:(fun (pid, { d; _ }) ->
let repr, _ =
Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:pid ~rank:0
in
match Map.find !projs repr with
| Some (Idx.Iterator s) -> Some s
| Some (Idx.Fixed_idx 0) when d = 0 -> None (* d=0 is invalid dimension, skip *)
| _ when d = 0 -> None
| _ ->
raise
@@ Shape_error
( [%string
"Concat component projection %{pid#Proj_id} (d=%{d#Int}) has no iterator"],
[] ))
in
let expected_idx =
if List.is_empty syms then Idx.Fixed_idx 0 else Idx.Concat syms
in
match Map.find !projs target_repr with
| None -> projs := Map.set !projs ~key:target_repr ~data:expected_idx
| Some existing_idx ->
let ok =
Idx.equal_axis_index existing_idx expected_idx
||
match (existing_idx, expected_idx) with
| Idx.Iterator s, Idx.Concat [ s' ] when Idx.equal_symbol s s' -> true
| _ -> false
in
if not ok then
raise
@@ Shape_error
( [%string
"Concat target projection %{target_pid#Proj_id} conflicts with existing index"],
[ Index_mismatch [ existing_idx; expected_idx ] ] ));

{
v_env;
Expand Down
15 changes: 11 additions & 4 deletions tensor/shape.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1887,13 +1887,20 @@ let%debug4_sexp derive_projections (update_step : update_step) : unit =
List.filter_map all_dims ~f:(fun dim ->
match dim with
| Row.Concat _ -> Some (Row.get_dim_index proj_env dim)
| Dim { proj_id = Some _; _ } -> (
match Row.get_product_proj proj_env dim with
| Some _ -> Some (Row.get_dim_index proj_env dim)
| None -> (
(* Also check if dim's projection maps to Idx.Concat *)
try
match Row.get_dim_index proj_env dim with
| Idx.Concat _ as idx -> Some idx
| _ -> None
with _ -> None))
| _ -> (
match Row.get_product_proj proj_env dim with
| Some _ -> Some (Row.get_dim_index proj_env dim)
| None ->
(* Also check if dim has a proj_id that maps to Idx.Concat *)
let idx = Row.get_dim_index proj_env dim in
(match idx with Idx.Concat _ -> Some idx | _ -> None)))
| None -> None))
in
let concat_groups : Idx.symbol list list =
List.filter_map product_indices ~f:(function Idx.Concat syms -> Some syms | _ -> None)
Expand Down
Loading