Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/kernels/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
const int& next_stage_idx = (iter_idx + 1) % kNumStages;
if (iter_idx + 1 < kNumIters and elect_one_sync()) {
tma_store_wait<kNumStages - kNumPrefetch - 1>();
const auto& offset_int4 = i + 32 * kNumSendUnrolls;
const auto& offset_int4 = (iter_idx + 1) * 32 * kNumSendUnrolls;
tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
}
__syncwarp();
Expand Down
13 changes: 9 additions & 4 deletions csrc/kernels/intranode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -983,8 +983,9 @@ __global__ void __launch_bounds__(kNumThreads, 1) combine(dtype_t* recv_x,
#ifndef DISABLE_SM90_FEATURES
if (i < hidden_int4_aligned) {
// Wait TMA arrival
tma_store_wait<kNumStages - 1>();
__syncwarp();
if (lane_id < kNumStages) {
tma_store_wait<0>();
}

// Write into TMA buffer
auto tma_stage_idx = (i / 32) % kNumStages;
Expand All @@ -993,7 +994,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) combine(dtype_t* recv_x,
// Issue TMA
tma_store_fence();
__syncwarp();
if (elect_one_sync()) {
if (lane_id < kNumStages) {
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
recv_int4 + token_idx * hidden_int4 + i,
Expand All @@ -1005,7 +1006,11 @@ __global__ void __launch_bounds__(kNumThreads, 1) combine(dtype_t* recv_x,
#endif
recv_int4[token_idx * hidden_int4 + i] = out_int4;
#ifndef DISABLE_SM90_FEATURES
}
}
}
// Flush all stores
tma_store_wait<0>();
__syncwarp();
#endif
}

Expand Down