@@ -139,6 +139,12 @@ struct WarpSpecializedLayerNorm
             scheduled_tiles++;
             // if (blockIdx.x == 0) printf("Pushed tile %d to DMA.\n", tile_id);
         }
+        // #if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+        // if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+        // {
+        //     cudaTriggerProgrammaticLaunchCompletion();
+        // }
+        // #endif
         sched2dma_w.push(0xffffffff);
         // if (blockIdx.x == 0) printf("Pushed tile -1 to DMA.\n");
         if (atomicAdd(&(param.counters->cta_completion_ctr), 1) == grid_sz - 1)
@@ -151,6 +157,12 @@ struct WarpSpecializedLayerNorm
         else
         {
            scheduled_tiles = 1;
+            // #if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+            // if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+            // {
+            //     cudaTriggerProgrammaticLaunchCompletion();
+            // }
+            // #endif
         }
         return scheduled_tiles;
     }
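
The scheduler's exit path above combines two completion signals: a `0xffffffff` sentinel pushed into the scheduler-to-DMA FIFO meaning "no more tiles", and a global atomic counter whose last incrementer knows every CTA has finished. A minimal sketch of the counter idiom follows; only `cta_completion_ctr` comes from the patch, the kernel and epilogue body are placeholders:

```cpp
__device__ unsigned cta_completion_ctr;  // assumed zeroed before launch

__global__ void worker_sketch()
{
    // ... per-CTA work ...
    if (threadIdx.x == 0)
    {
        // atomicAdd returns the pre-increment value, so exactly one CTA
        // observes gridDim.x - 1 and can run the epilogue exactly once.
        if (atomicAdd(&cta_completion_ctr, 1u) == gridDim.x - 1)
        {
            cta_completion_ctr = 0;  // e.g. reset the counter for the next launch
        }
    }
}
```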
@@ -201,25 +213,30 @@ struct WarpSpecializedLayerNorm
         }
         // if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);

+        if constexpr (FIRST_RUN)
+        {
+            cudaGridDependencySynchronize();
+        }
+        const uint32_t eff_m_block
+            = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
         const auto tx
-            = (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
-            + (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
+            = (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+            + (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);

         auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);

         // if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base);

-        if constexpr (FIRST_RUN)
-        {
-            cudaGridDependencySynchronize();
-        }

         for (int i = 0; i < Traits::M_BLOCK; i++)
         {
-            load_a_vec(&param.input[(m_base + i) * param.n],
-                __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
-                param.n * sizeof(typename Traits::InputType),
-                __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+            if (i < eff_m_block) [[likely]]
+            {
+                load_a_vec(&param.input[(m_base + i) * param.n],
+                    __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
+                    param.n * sizeof(typename Traits::InputType),
+                    __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+            }
         }

         // Use templated lambdas to defer resolving the symbols like "param.residual".
@@ -231,10 +248,13 @@ struct WarpSpecializedLayerNorm
         {
             for (int i = 0; i < Traits::M_BLOCK; i++)
             {
-                load_a_vec(&param.residual[(m_base + i) * param.n],
-                    __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
-                    param.n * sizeof(typename Traits::InputType),
-                    __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                if (i < eff_m_block) [[likely]]
+                {
+                    load_a_vec(&param.residual[(m_base + i) * param.n],
+                        __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
+                        param.n * sizeof(typename Traits::InputType),
+                        __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                }
             }
         }(param);
     }
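
Both guarded-load hunks implement the same tail-tile fix: when `param.m` is not a multiple of `Traits::M_BLOCK`, the last tile owns fewer valid rows, so the expected TMA transaction bytes (`tx`) and the per-row `load_a_vec` issues must both be clamped to `eff_m_block`; otherwise the mbarrier would wait for bytes that are never sent and the loads would read past the end of the tensor. The shape of the fix, reduced to placeholder names (`M_BLOCK`, `issue_row_copy`, and the parameters are illustrative):

```cpp
#include <cstdint>

constexpr uint32_t M_BLOCK = 4;               // illustrative tile height
__device__ void issue_row_copy(uint32_t row); // placeholder for load_a_vec + TMA

__device__ void load_tile(uint32_t m_total, uint32_t m_base,
                          uint32_t n, uint32_t elem_bytes)
{
    // Rows this tile actually owns: full tiles get M_BLOCK, the tail the rest.
    const uint32_t eff_m_block
        = (m_total - m_base < M_BLOCK) ? (m_total - m_base) : M_BLOCK;

    // The mbarrier's expected-transaction byte count must match what is issued.
    const uint32_t tx_bytes = eff_m_block * n * elem_bytes;
    (void)tx_bytes;  // would be passed to the barrier's arrive/expect-tx

    for (uint32_t i = 0; i < M_BLOCK; i++)
    {
        if (i < eff_m_block)             // predicate rather than shorten the loop,
            issue_row_copy(m_base + i);  // keeping it statically unrollable
    }
}
```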
@@ -423,6 +443,13 @@ struct WarpSpecializedLayerNorm

     using FusedOperator = GetFusedOperator<typename Traits::FusedOperator>;

+#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+    if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+    {
+        // Ensure upstream kernel writes are visible before reading dependent activation/residual data.
+        cudaGridDependencySynchronize();
+    }
+#endif
     FusedOperator fused_operator(param);

     static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1);
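
`cudaGridDependencySynchronize()` only does useful work when the kernel is launched with Programmatic Dependent Launch (PDL) enabled; it then blocks until the upstream grid's memory writes are visible. A host-side sketch of how such a launch pair is typically configured, assuming CUDA 12 and sm_90+; kernel names and launch shapes are placeholders:

```cpp
#include <cuda_runtime.h>

// Placeholders: "producer" would be the upstream kernel, "consumer" this one.
__global__ void producer() { /* ... write activations, then trigger ... */ }
__global__ void consumer() { /* ... cudaGridDependencySynchronize(), read ... */ }

void launch_overlapped(cudaStream_t stream)
{
    producer<<<132, 384, 0, stream>>>();

    cudaLaunchConfig_t cfg = {};
    cfg.gridDim = 132;
    cfg.blockDim = 384;
    cfg.stream = stream;

    // Opt the consumer into PDL: it may start before the producer retires,
    // but its cudaGridDependencySynchronize() will not return until the
    // producer's writes are visible.
    cudaLaunchAttribute attr = {};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;
    cfg.attrs = &attr;
    cfg.numAttrs = 1;

    cudaLaunchKernelEx(&cfg, consumer);
}
```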
@@ -446,6 +473,9 @@ struct WarpSpecializedLayerNorm
         {
             m_base = block_id;
         }
+        const uint32_t eff_m_block
+            = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
+
         // if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);

         // Peek for data ready.
@@ -613,11 +643,11 @@ struct WarpSpecializedLayerNorm
            {
                mean[m_offset] /= param.n;
                variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
-                    + (Traits::AccumulatorType)(1e-5));
+                    + (Traits::AccumulatorType)(param.layernorm_eps));
            }
            else
            {
-                variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
+                variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
            }
        }

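
This hunk replaces the hard-coded `1e-5` with a runtime `param.layernorm_eps` (a params field assumed to be introduced elsewhere in the patch). The epsilon sits under the `rsqrtf`, so per row the kernel computes, in sketch form (`sum` and `sum_sq` are illustrative accumulators for Σx and Σx²):

```cpp
// rstd = 1 / sqrt(Var[x] + eps), with Var[x] = E[x^2] - E[x]^2.
// The no-mean branch (RMSNorm-style) is simply rsqrtf(sum_sq / n + eps).
__device__ float row_rstd(float sum, float sum_sq, float n, float eps)
{
    const float mean = sum / n;                     // E[x]
    return rsqrtf(sum_sq / n - mean * mean + eps);  // eps keeps this finite for constant rows
}
```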
@@ -659,8 +689,7 @@ struct WarpSpecializedLayerNorm
            }
        }

-#pragma unroll Traits::M_BLOCK
-        for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
+        for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
        {
            auto m = m_base + m_offset;

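
`#pragma unroll Traits::M_BLOCK` had to go with the bound change: an unroll count must be a compile-time constant, and `eff_m_block` is a runtime value. If full unrolling were still wanted, one alternative (a sketch, not what the patch does) is to keep the static bound and predicate the body, as the DMA-side loops above do:

```cpp
#pragma unroll
for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
{
    if (m_offset < eff_m_block)  // guard instead of a runtime trip count
    {
        // ... per-row epilogue ...
    }
}
```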
@@ -801,23 +830,22 @@ struct WarpSpecializedLayerNorm
        shared->init(threadIdx.x == 0);

        __syncthreads();
-#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
-#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
        if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
        {
            auto block_id = blockIdx.x;
            auto warp_id = threadIdx.x / 32;
            auto lane_id = threadIdx.x % 32;
            auto tid_in_wg = threadIdx.x % 128;
-
+            // cudaGridDependencySynchronize();
            if (warp_id < 4)
            {
                asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
                if (warp_id == 0)
                {
                    scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
-                    // PRE-EXIT after all tiles have been scheduled.
-                    cudaTriggerProgrammaticLaunchCompletion();
+                    // PRE-EXIT after all tiles have been scheduled.
+                    // cudaTriggerProgrammaticLaunchCompletion();
                }
                else if (warp_id == 1)
                {
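
The unchanged lines around this hunk show the warp-specialization register split that the broadened `__CUDA_ARCH__ >= 900` guard now covers: producer warps (scheduler and DMA, warps 0 to 3) give registers back with `setmaxnreg.dec`, and the math warpgroups claim them with `setmaxnreg.inc` in the hunk below. Note that `setmaxnreg` is an architecture-accelerated PTX instruction, so replacing the `__CUDA_ARCH_FEAT_SM90_ALL` / `__CUDA_ARCH_FEAT_SM100_ALL` tests presumably still assumes an sm_90a/sm_100a-style compilation target. The pattern in isolation (the 56/224 split is the kernel's own; everything else is illustrative):

```cpp
// Register reallocation between specialized warpgroups (sm_90a+ sketch).
__device__ void rebalance_registers(int warp_id)
{
    if (warp_id < 4)
    {
        // Producer warpgroup: shrink to 56 registers per thread.
        asm volatile("setmaxnreg.dec.sync.aligned.u32 56;");
    }
    else
    {
        // Consumer warpgroups: grow to 224 registers per thread.
        asm volatile("setmaxnreg.inc.sync.aligned.u32 224;");
    }
}
```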
@@ -829,8 +857,11 @@ struct WarpSpecializedLayerNorm
                asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
                compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
            }
+            __syncthreads();
+            asm volatile("membar.gl;" : : : "memory");
+            cudaTriggerProgrammaticLaunchCompletion();
+            // cudaTriggerProgrammaticLaunchCompletion();
        }
-#endif
 #endif
    }
 };
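
The net effect of the last two hunks: the programmatic-launch trigger moves from the scheduler warp, which could release the dependent kernel while math warpgroups were still writing results, to a common exit point. All warps first rejoin at `__syncthreads()`, `membar.gl` publishes the CTA's global-memory writes at GPU scope, and only then does `cudaTriggerProgrammaticLaunchCompletion()` let the next kernel's `cudaGridDependencySynchronize()` proceed. The ordering in isolation (requires sm_90+):

```cpp
#include <cuda_runtime.h>

// Exit sequence for a kernel overlapped with its successor via PDL.
__device__ void pdl_exit_sketch()
{
    __syncthreads();                            // 1. all warps finished writing
    asm volatile("membar.gl;" ::: "memory");    // 2. fence writes at GPU scope
    cudaTriggerProgrammaticLaunchCompletion();  // 3. release the dependent kernel
}
```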