
Conversation

@cyhdmjzzy

  1. Fix the offset_int4 for internode_ll combine

internode_ll.cu, with i = lane_id * kNumSendUnrolls + iter_idx * 32 * kNumSendUnrolls:

const auto& offset_int4 = i + 32 * kNumSendUnrolls
                        = lane_id * kNumSendUnrolls + iter_idx * 32 * kNumSendUnrolls + 32 * kNumSendUnrolls
                        = lane_id * kNumSendUnrolls + (iter_idx + 1) * 32 * kNumSendUnrolls

Wouldn't the lane_id * kNumSendUnrolls term cause a discrepancy?
elect_one_sync() always elects lane 0, so lane_id is 0 inside that branch and the expression collapses to (sketched below):
const auto& offset_int4 = (iter_idx + 1) * 32 * kNumSendUnrolls
See also #525 and #359.
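
Below is a minimal, self-contained CUDA sketch of that collapse, not the DeepEP kernel itself: the loop shape, the value of kNumSendUnrolls, and the elect_one_sync() stand-in (which picks the lowest active lane, matching the lane-0 behavior described above) are assumptions for illustration only.

```cuda
// Minimal sketch (not the DeepEP kernel): shows why the lane_id term drops out
// of offset_int4 inside an elect_one_sync() branch. Loop shape, constants, and
// the elect_one_sync() stand-in are illustrative assumptions.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kNumSendUnrolls = 4;  // hypothetical value, for illustration

// Stand-in for elect_one_sync(): pick the lowest active lane, which for a
// fully active warp (as in this demo) is always lane 0.
__device__ __forceinline__ bool elect_one_sync_stub() {
    return (threadIdx.x % 32) == (__ffs(__activemask()) - 1);
}

__global__ void offset_demo(int hidden_int4) {
    const int lane_id = threadIdx.x % 32;
    // Same iteration pattern as in the derivation above:
    //   i = lane_id * kNumSendUnrolls + iter_idx * 32 * kNumSendUnrolls
    for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_int4;
         i += 32 * kNumSendUnrolls, ++iter_idx) {
        if (elect_one_sync_stub()) {
            const int offset_a = i + 32 * kNumSendUnrolls;              // original form
            const int offset_b = (iter_idx + 1) * 32 * kNumSendUnrolls; // collapsed form
            // Only lane 0 reaches this branch, so lane_id == 0 here and the two
            // expressions agree; no mismatch is ever printed.
            if (offset_a != offset_b)
                printf("mismatch at iter %d: %d vs %d\n", iter_idx, offset_a, offset_b);
        }
    }
}

int main() {
    offset_demo<<<1, 32>>>(1024);
    cudaDeviceSynchronize();
    return 0;
}
```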

  1. Fix the MultiStages for intranode combine
    Each invocation of tma_store_1d issues a TMA bulk_group, and multiple TMA bulk_groups issued by the same thread execute serially (see the NVIDIA PTX documentation, 9.7.9.27.2.2, Data Movement and Conversion Instructions: cp.async.bulk.wait_group). Because elect_one_sync() always selects lane 0, every tma_store_1d inside if (elect_one_sync()) is issued by that same lane and therefore runs serially, preventing a true MultiStage pipeline. Assigning the first kNumStages lanes to distinct stages enables genuine MultiStage concurrency; a sketch of this lane-per-stage pattern follows after this item.

Furthermore, the inline assembly in tma_store_wait, asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(N) : "memory"), waits until the number of the executing thread's outstanding TMA bulk_groups that have not yet finished reading from shared memory drops to N. If the first kNumStages lanes each own one stage, then a lane that issues tma_store_1d only needs to wait for the single TMA bulk_group it issued previously for its corresponding stage of shared memory.

Finally, the __syncwarp() immediately after tma_store_wait() is unnecessary, because tma_store_fence() followed by __syncwarp() is sufficient for visibility.
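
The following is a minimal CUDA sketch of the lane-per-stage pattern, illustrating both the per-lane issuing and the per-thread wait_group semantics. It is not the DeepEP intranode kernel: kNumStages, kStageBytes, the omitted producer code, and the *_sketch wrappers (written directly against the PTX cp.async.bulk, cp.async.bulk.wait_group.read, and fence.proxy.async instructions) are assumptions for illustration, and the code assumes an sm_90a target with CUDA 12+.

```cuda
// Minimal sketch (assumes sm_90a, CUDA 12+). kNumStages, kStageBytes, and the
// *_sketch wrappers below are illustrative stand-ins, not the DeepEP API.
#include <cstdint>
#include <cuda_runtime.h>

constexpr int kNumStages  = 3;     // hypothetical stage count
constexpr int kStageBytes = 2048;  // hypothetical per-stage size (16-byte multiple)

// Issue a 1D bulk async store (shared -> global) and commit it as one bulk_group.
__device__ __forceinline__ void tma_store_1d_sketch(const void* smem_src, void* gmem_dst, int num_bytes) {
    const auto smem_int = static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
    asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n\t"
                 "cp.async.bulk.commit_group;"
                 :: "l"(reinterpret_cast<uint64_t>(gmem_dst)), "r"(smem_int), "r"(num_bytes)
                 : "memory");
}

// Wait until at most N of *this thread's* bulk_groups are still reading shared memory.
template <int N>
__device__ __forceinline__ void tma_store_wait_sketch() {
    asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}

// Make preceding shared-memory writes visible to the async (TMA) proxy.
__device__ __forceinline__ void tma_store_fence_sketch() {
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
}

// One warp drains kNumStages shared-memory stages to global memory.
// Lane s (s < kNumStages) owns stage s, so the bulk_groups of different stages
// live on different threads and can overlap; each lane waits only for itself.
__global__ void multistage_store_demo(uint8_t* gmem_out, int num_rounds) {
    __shared__ __align__(16) uint8_t smem[kNumStages][kStageBytes];
    const int lane_id = threadIdx.x % 32;

    // Producer code that fills smem[stage] each round (and its synchronization
    // with buffer reuse) is omitted; only the issue/wait pattern is shown.
    for (int round = 0; round < num_rounds; ++round) {
        const int stage = round % kNumStages;

        // Make the shared-memory contents visible to the async (TMA) proxy,
        // then have the warp agree the stage is ready to be stored.
        tma_store_fence_sketch();
        __syncwarp();

        if (lane_id == stage) {
            // cp.async.bulk.wait_group counts only the bulk_groups committed by
            // the executing thread, so the owning lane waits solely for its own
            // previously issued store of this stage (if any), not for the other
            // stages' stores issued by other lanes.
            tma_store_wait_sketch<0>();
            tma_store_1d_sketch(&smem[stage][0],
                                gmem_out + static_cast<size_t>(round) * kStageBytes,
                                kStageBytes);
        }
    }
    // Drain whatever this lane still has in flight before exiting.
    tma_store_wait_sketch<0>();
}

int main() {
    const int num_rounds = 8;
    uint8_t* gmem_out = nullptr;
    cudaMalloc(&gmem_out, static_cast<size_t>(num_rounds) * kStageBytes);
    multistage_store_demo<<<1, 32>>>(gmem_out, num_rounds);
    cudaDeviceSynchronize();
    cudaFree(gmem_out);
    return 0;
}
```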

cyhdmjzzy marked this pull request as draft on December 19, 2025 at 15:01
cyhdmjzzy marked this pull request as ready for review on December 19, 2025 at 15:03