1515 Output ,
1616 Placeholder ,
1717 Read ,
18+ ReduceOp ,
1819 Reduction ,
1920 Write ,
2021 get_custom ,
3334from ..utils .general_utils import (
3435 get_hardware_constraint ,
3536 get_largest_index_and_size ,
37+ get_workgroup_constraints ,
3638 partial ,
3739)
3840from ..utils .mma_utils import (
@@ -145,12 +147,21 @@ def set_node_indices(
145147 print_trace (trace )
146148
147149 graph_passes = []
148- if mma_mapping != {} :
150+ if mma_mapping :
149151 graph_passes += [
150152 partial (
151153 set_thread_dependent_index_from_mma , constraints , mma_mapping , trace
152154 )
153155 ]
156+ elif reduce_mapping := get_reduce_mapping (trace , constraints ):
157+ graph_passes += [
158+ partial (
159+ set_thread_dependent_index_from_reduce ,
160+ constraints ,
161+ trace ,
162+ reduce_mapping ,
163+ )
164+ ]
154165 else :
155166 graph_passes += [
156167 partial (set_thread_dependent_index_from_read_write , constraints , trace )
@@ -516,9 +527,7 @@ def set_thread_dependent_index_from_read_write(
516527 assert sources , "No read nodes found in the graph."
517528
518529 visited = set ()
519- workgroup_constraints = [
520- c for c in constraints if isinstance (c , WorkgroupConstraint )
521- ]
530+ workgroup_constraints = get_workgroup_constraints (constraints )
522531 symbolic_constraints = [c for c in constraints if isinstance (c , SymbolicAlias )]
523532 for source in sources :
524533 visited = visited .union (set ([x for x in sources ]))
@@ -533,6 +542,138 @@ def set_thread_dependent_index_from_read_write(
533542 )
534543
535544
def get_reduce_mapping(
    trace: CapturedTrace, constraints: list[Constraint]
) -> dict[ReduceOp, dict[IndexSymbol, IndexSequence]]:
    """
    Build a per-dimension index sequence for every reduce op found in the trace.

    For each reduce op, the reduction dim is mapped onto threads using
    workgroup dim 0, and each remaining indexing dim that has a
    WorkgroupConstraint is mapped the same way read/write nodes are, along
    that constraint's workgroup dim. Indexing dims without a
    WorkgroupConstraint receive no entry.

    Example:
    ```
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    ...
    @tkw.reduction(N, init_args=[init_max, init_sum])
    def repeat(
        partial_max: tkl.Register[M, tkl.f32],
    ) -> tkl.Register[M, tkl.f32]:
        res = tkw.read(a) # [M, N]
        partial_max = tkw.max(res, partial_max, dim=N) # {N: 2*$T0 : 2 : 1, M: $T1 : 1 : 1}
        ...
    ```

    Returns a dict keyed by the reduce op (custom node), whose values map each
    dim to its index sequence. The dict is empty when the trace has no reduce ops.
    """
    reduce_nodes = trace.walk(lambda node: isinstance(get_custom(node), ReduceOp))
    hardware_constraint = get_hardware_constraint(constraints)
    workgroup_constraints = get_workgroup_constraints(constraints)

    mapping = {}
    for fx_node in reduce_nodes:
        reduce_op = get_custom(fx_node)
        reduction_dim = reduce_op.dim

        # The reduction dim's vector is split evenly between the threads of a
        # wave, so each thread owns `vector_size // threads_per_wave` elements.
        threads_per_wave = hardware_constraint.threads_per_wave
        vector_size = hardware_constraint.vector_shapes[reduction_dim]
        assert (
            vector_size % threads_per_wave == 0
        ), f"Vector size {reduction_dim}={vector_size} must be divisible by threads per wave {threads_per_wave}"
        per_thread = vector_size // threads_per_wave
        reduction_stride = compute_stride(
            reduce_op.indexing_dims, hardware_constraint.vector_shapes, reduction_dim
        )
        op_index = {
            reduction_dim: hardware_constraint.apply_read_write_thread_mapping(
                reduction_dim, 0, per_thread, reduction_stride
            )
        }

        # Remaining dims: one element per thread, placed along the workgroup
        # dim named by the (at most one) matching WorkgroupConstraint.
        for other_dim in reduce_op.indexing_dims:
            other_stride = compute_stride(
                reduce_op.indexing_dims, hardware_constraint.vector_shapes, other_dim
            )
            matching = [c for c in workgroup_constraints if c.dim == other_dim]
            assert (
                len(matching) <= 1
            ), f"Multiple workgroup constraints for dimension {other_dim}"
            if not matching:
                continue

            op_index[other_dim] = hardware_constraint.apply_read_write_thread_mapping(
                other_dim, matching[0].workgroup_dim, 1, other_stride
            )

        mapping[reduce_op] = op_index

    return mapping
616+
617+
def populate_reduce_source_indices(
    node: ReduceOp,
    hardware_constraint: HardwareConstraint,
    workgroup_constraints: list[WorkgroupConstraint],
    index: dict[IndexSymbol, IndexSequence],
):
    """
    Build the (node, index, vector_shapes) triples seeded from one reduce op.

    The reduce op's arg(s) are paired with the full `index`. The init value
    (when present) and the reduce op itself are paired with a copy of `index`
    that has the reduction dim removed. `workgroup_constraints` is accepted
    but not used here.

    Returns the list of triples, ordered: args, init (if any), the reduce op.
    Raises KeyError if `index` has no entry for `node.dim`.
    """
    vector_shapes = hardware_constraint.vector_shapes

    # `node.arg` may be a single value or a sequence of values.
    reduce_args = node.arg if isinstance(node.arg, Sequence) else [node.arg]
    triples = [(get_custom(arg), index, vector_shapes) for arg in reduce_args]

    # Only the reduce args carry the reduction dim; the init and the result
    # of the reduction do not.
    result_index = copy(index)
    del result_index[node.dim]

    if node.init:
        triples.append((get_custom(node.init), result_index, vector_shapes))

    triples.append((node, result_index, vector_shapes))
    return triples
645+
646+
def set_thread_dependent_index_from_reduce(
    constraints: Sequence[Constraint],
    trace: CapturedTrace,
    reduce_mapping: dict[ReduceOp, dict[IndexSymbol, IndexSequence]],
):
    """
    Propagate thread-dependent indices through the graph, using reduce ops as roots.

    Each reduce op in the trace is looked up in `reduce_mapping`, expanded into
    its seed (node, index, vector_shapes) triples, and handed to
    `propagate_indices`. Asserts that the trace contains at least one reduce op.
    """
    hardware_constraint = get_hardware_constraint(constraints)
    reduce_ops = [
        get_custom(fx_node)
        for fx_node in trace.walk(lambda node: isinstance(get_custom(node), ReduceOp))
    ]
    assert reduce_ops, "No reduce nodes found in the graph."

    seen = set()
    workgroup_constraints = get_workgroup_constraints(constraints)
    symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)]
    for root in reduce_ops:
        # Treat every reduce op except the current root as already visited;
        # presumably this keeps propagation from one root from walking into
        # the others (depends on propagate_indices -- confirm there).
        seen = seen.union(reduce_ops)
        seen.remove(root)
        root_index = reduce_mapping[root]
        seeds = populate_reduce_source_indices(
            root, hardware_constraint, workgroup_constraints, root_index
        )
        seen = propagate_indices(seeds, seen, symbolic_constraints)
675+
676+
536677def set_post_expansion_indices (trace : CapturedTrace , constraints : list [Constraint ]):
537678 """
538679 Add offsets to the indices based on the expanded dims.
0 commit comments