@@ -1682,37 +1682,36 @@ def mul2_clc(
16821682 n_elements ,
16831683 BLOCK_SIZE : tl .constexpr ,
16841684 ):
1685- pid = tl .program_id (axis = 0 )
1686- block_start = pid * BLOCK_SIZE
1687-
1688- tid = tlx .thread_id (axis = 0 )
1689-
1690- offsets = block_start + tl .arange (0 , BLOCK_SIZE )
1691- mask = offsets < n_elements
1685+ tile_id = tl .program_id (axis = 0 )
16921686
1693- x = tl .load (x_ptr + offsets , mask = mask )
1694- y = tl .load (y_ptr + offsets , mask = mask )
1695- output = x * y
1696- tl .store (z_ptr + offsets , output , mask = mask )
1687+ # CLC Init
1688+ clc_phase_producer = 1
1689+ clc_phase_consumer = 0
1690+ # NUM_CLC_STAGES=1
1691+ # NUM_CONSUMERS=1
1692+ clc_context = tlx .clc_create_context (1 , 1 )
16971693
1698- bars = tlx .alloc_barriers (num_barriers = 1 )
1699- clc_mbar = bars [0 ]
1694+ while tile_id != - 1 :
1695+ # CLC producer
1696+ tlx .clc_producer (clc_context , 0 , clc_phase_producer )
1697+ clc_phase_producer ^= 1
17001698
1701- responses = tlx .alloc_clc_responses (num_responses = 1 )
1702- clc_response = tlx .local_view (responses , 0 )
1703- tlx .barrier_expect_bytes (clc_mbar , 16 ) # CLC response is 16-byte
1699+ block_start = tile_id * BLOCK_SIZE
17041700
1705- # Issue async clc.try_cancel for the next available CTA
1706- tlx . clc_issue ( clc_response , clc_mbar )
1701+ offsets = block_start + tl . arange ( 0 , BLOCK_SIZE )
1702+ mask = offsets < n_elements
17071703
1708- # Wait for clc.try_cancel finishes
1709- tlx .barrier_wait (clc_mbar , 0 )
1704+ x = tl .load (x_ptr + offsets , mask = mask )
1705+ y = tl .load (y_ptr + offsets , mask = mask )
1706+ output = x * y
1707+ tl .store (z_ptr + offsets , output , mask = mask )
17101708
1711- # Extract CTA ID from CLC response
1712- res = tlx .clc_query (clc_response )
1709+ # CLC consumer
1710+ tile_id = tlx .clc_consumer (clc_context , 0 , clc_phase_consumer )
1711+ clc_phase_consumer ^= 1
17131712
1714- if tid == 0 :
1715- tl .device_print ("Extracted CtaID" , res )
1713+ if tlx . thread_id ( axis = 0 ) == 0 :
1714+ tl .device_print ("Extracted CtaID" , tile_id )
17161715
17171716 torch .manual_seed (0 )
17181717 # number of kernels to launch in a non-persistent mode
@@ -1731,11 +1730,7 @@ def mul2_clc(
17311730 assert re .search ((r'clusterlaunchcontrol.query_cancel.is_canceled.pred.b128' ), ptx , flags = re .DOTALL )
17321731 assert re .search ((r'clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128' ), ptx , flags = re .DOTALL )
17331732
1734- # Each worker uses the {blockIdx.x, blockIdx.y, blockIdx.z} coordinate as the first output tile to process
1735- # and uses the CLC query for subsequent processing of output tiles.
1736- # However in our test those CTAs left from the first round won't execute.
1737- # Its nonzero count MUST be different from original size.
1738- assert (torch .count_nonzero (output ) != size )
1733+ assert (torch .count_nonzero (output ) == size )
17391734
17401735
17411736@pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
0 commit comments