@@ -1944,6 +1944,8 @@ def __init__(
19441944 self .fixed_config = fixed_config
19451945 super ().__init__ (tiling , ** kwargs )
19461946 self .cse = TritonCSE (self .newvar_prefix , self .suffix )
1947+ # Cache of values that can be reused for the prologue.
1948+ self .prologue_cache : dict [str , str ] = {}
19471949 self .prologue : IndentedBuffer = IndentedBuffer ()
19481950 self .post_loop_combine : IndentedBuffer = IndentedBuffer ()
19491951 self .post_loop_store : IndentedBuffer = IndentedBuffer ()
@@ -2485,42 +2487,49 @@ def codegen_block_ptr(
24852487 and self .range_trees [- 1 ].is_loop
24862488 and indexing .has_rindex ()
24872489 ) or indexing .can_lift :
2488- block_descriptor_id = next (self .block_ptr_id )
2489- if isinstance (indexing , BlockPtrOptions ):
2490- block_descriptor = f"block_ptr{ block_descriptor_id } "
2490+ if indexing .can_lift and var in self .prologue_cache :
2491+ # Check for epilogue subtiling to reuse the same
2492+ # tensor descriptor.
2493+ block_descriptor = self .prologue_cache [var ]
24912494 else :
2492- block_descriptor = f"tma_descriptor{ block_descriptor_id } "
2493- line_body = DeferredLine (
2494- name , f"{ block_descriptor } = { indexing .format (var , roffset = False )} "
2495- )
2496- if indexing .can_lift :
2497- self .prologue .writeline (line_body )
2498- else :
2499- self .body .writeline (line_body )
2500-
2501- if isinstance (indexing , BlockPtrOptions ):
2502- # Store for later use. If the buffer is removed the below advancements
2503- # are no longer necessary
2504- self .block_ptr_to_buffer [block_descriptor ] = name
2495+ block_descriptor_id = next (self .block_ptr_id )
2496+ if isinstance (indexing , BlockPtrOptions ):
2497+ block_descriptor = f"block_ptr{ block_descriptor_id } "
2498+ else :
2499+ block_descriptor = f"tma_descriptor{ block_descriptor_id } "
2500+ line_body = DeferredLine (
2501+ name , f"{ block_descriptor } = { indexing .format (var , roffset = False )} "
2502+ )
2503+ if indexing .can_lift :
2504+ self .prologue .writeline (line_body )
2505+ # Cache the descriptor for epilogue subtiling
2506+ self .prologue_cache [var ] = block_descriptor
2507+ else :
2508+ self .body .writeline (line_body )
25052509
2506- # Generate block pointer advancements, for later use.
2507- for symt in TritonSymbols .reduction_types :
2508- advance_offsets = indexing .advance_roffset (symt )
2510+ if isinstance (indexing , BlockPtrOptions ):
2511+ # Store for later use. If the buffer is removed the below advancements
2512+ # are no longer necessary
2513+ self .block_ptr_to_buffer [block_descriptor ] = name
2514+
2515+ # Generate block pointer advancements, for later use.
2516+ for symt in TritonSymbols .reduction_types :
2517+ advance_offsets = indexing .advance_roffset (symt )
2518+
2519+ # Ignore identity advancements.
2520+ if all (
2521+ V .graph .sizevars .statically_known_equals (
2522+ offset , sympy .Integer (0 )
2523+ )
2524+ for offset in advance_offsets
2525+ ):
2526+ continue
25092527
2510- # Ignore identity advancements.
2511- if all (
2512- V .graph .sizevars .statically_known_equals (
2513- offset , sympy .Integer (0 )
2528+ advancements = self .pointer_advancements [symt ]
2529+ assert block_descriptor not in advancements , (
2530+ f"duplicate advancement for pointer '{ block_descriptor } ' at type '{ symt } '"
25142531 )
2515- for offset in advance_offsets
2516- ):
2517- continue
2518-
2519- advancements = self .pointer_advancements [symt ]
2520- assert block_descriptor not in advancements , (
2521- f"duplicate advancement for pointer '{ block_descriptor } ' at type '{ symt } '"
2522- )
2523- advancements [block_descriptor ] = advance_offsets
2532+ advancements [block_descriptor ] = advance_offsets
25242533 else :
25252534 block_descriptor = indexing .format (var )
25262535 return block_descriptor , other
@@ -3879,6 +3888,7 @@ def codegen_prologue(self, code: IndentedBuffer):
38793888
38803889 code .splice (self .prologue )
38813890 self .prologue .clear ()
3891+ self .prologue_cache .clear ()
38823892
38833893 def codegen_body (self ):
38843894 """
0 commit comments