|
48 | 48 |
|
49 | 49 | import cutlass.cute as cute |
50 | 50 | from cutlass.cutlass_dsl import Boolean, if_generate |
51 | | -from cutlass.pipeline import (CooperativeGroup, PipelineAsync, PipelineOp, |
52 | | - PipelineState) |
| 51 | +from cutlass.pipeline import (Agent, CooperativeGroup, PipelineAsync, |
| 52 | + PipelineOp, PipelineState, agent_sync) |
53 | 53 |
|
54 | 54 |
|
55 | 55 | def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): |
@@ -374,3 +374,153 @@ def then_body(): |
374 | 374 | self.producer_acquire(state) |
375 | 375 |
|
376 | 376 | if_generate(is_leader_cta, then_body) |
| 377 | + |
| 378 | + |
@dataclass(frozen=True)
class PipelineCpAsyncUmma(PipelineAsync):
    """
    Pipeline pairing LDGSTS (cp.async) producers with UMMA consumers.

    Intended use:
    - Producers issue LDGSTS (cp.async) instructions to move data from global
      memory into shared memory.
    - Consumers are UMMA warps that run MMA operations on the loaded data.

    How it differs from PipelineAsyncUmma:
    - Suited to loads that gather/permute data on the way in.
    - Used in this kernel for the A and SFA matrices, which are addressed
      through a token-based gather.
    """

    # tcgen05 CTA group handed to the empty-barrier arrive in consumer_release.
    cta_group: cute.nvgpu.tcgen05.CtaGroup

    @staticmethod
    def _compute_leading_cta_rank(cta_v_size):
        """
        Returns this CTA's rank in the cluster rounded down to a multiple of
        ``cta_v_size``, i.e. the rank of the leading CTA of its group.
        """
        rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster())
        return (rank_in_cluster // cta_v_size) * cta_v_size

    @staticmethod
    def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
        """
        Tells whether the current threadblock is a leader for 2CTA kernels.
        With a V mode of size 1 (1CTA) every threadblock is a leader.
        """
        block_x, block_y, _ = cute.arch.block_idx()
        # Coordinate is laid out as (V, M, N, K); only the V entry decides
        # leadership.
        coord_v = block_x % cute.size(cta_layout_vmnk, mode=[0])
        coord_m = block_x // cute.size(cta_layout_vmnk, mode=[0])
        mma_coord = (coord_v, coord_m, block_y, None)
        return mma_coord[0] == 0

    @staticmethod
    def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
        """
        Builds the mask used to signal arrivals to multicasting threadblocks:
        the multicast mask of this CTA OR-ed with the one of its peer, where
        the peer is the CTA whose V coordinate has the lowest bit flipped.
        """
        rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster())
        own_coord_vmnk = cta_layout_vmnk.get_flat_coord(rank_in_cluster)
        own_mask = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, own_coord_vmnk, mcast_mode=0)

        # Same coordinate as ours except for the toggled V entry.
        peer_coord_vmnk = (own_coord_vmnk[0] ^ 1, *own_coord_vmnk[1:])
        peer_mask = cute.nvgpu.cpasync.create_tma_multicast_mask(
            cta_layout_vmnk, peer_coord_vmnk, mcast_mode=0)
        return own_mask | peer_mask

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        defer_sync: bool = False,
        enable_cp_async: bool = False,
    ):
        """Builds and initializes a PipelineCpAsyncUmma.

        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: int
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param barrier_storage: Pointer to the shared memory address for this
            pipeline's mbarriers; required despite the ``None`` default
        :type barrier_storage: cute.Pointer, optional
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout, optional
        :param defer_sync: When True, skip the agent sync issued at the end of
            creation
        :type defer_sync: bool, optional
        :param enable_cp_async: When True the producer is registered as
            ``PipelineOp.AsyncLoad``, otherwise as ``PipelineOp.AsyncThread``
        :type enable_cp_async: bool, optional
        :raises ValueError: If barrier_storage is not a cute.Pointer instance
        :return: A new PipelineCpAsyncUmma instance configured with the
            provided parameters
        :rtype: PipelineCpAsyncUmma
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer = (
            PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread,
            producer_group,
        )
        consumer = (PipelineOp.TCGen05Mma, consumer_group)

        # The "full" barriers sit at the 8-aligned base; the "empty" barriers
        # follow them, num_stages entries further on.
        barrier_base = barrier_storage.align(min_align=8)
        sync_object_full = PipelineAsync._make_sync_object(
            barrier_base, num_stages, producer)
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_base + num_stages, num_stages, consumer)

        single_cta_mma = (cta_layout_vmnk is None
                          or cute.size(cta_layout_vmnk, mode=[0]) == 1)
        if single_cta_mma:
            # Not a 2CTA tcgen05 MMA: no masks are needed.
            cta_group = cute.nvgpu.tcgen05.CtaGroup.ONE
            producer_mask = None
            consumer_mask = None
        else:
            cta_group = cute.nvgpu.tcgen05.CtaGroup.TWO
            # With 2CTA UMMAs the producer arrives on the leading CTA's
            # mbarrier, so it needs that CTA's rank as its target.
            producer_mask = PipelineCpAsyncUmma._compute_leading_cta_rank(
                cute.size(cta_layout_vmnk, mode=[0]))
            # The consumer needs the mask of CTAs it has to signal.
            consumer_mask = PipelineCpAsyncUmma._compute_peer_cta_mask(
                cta_layout_vmnk)

        if not defer_sync:
            # Here the whole layout size is checked, not only the V mode.
            single_cta_cluster = (cta_layout_vmnk is None
                                  or cute.size(cta_layout_vmnk) == 1)
            if single_cta_cluster:
                agent_sync(Agent.ThreadBlock)
            else:
                agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)

        return PipelineCpAsyncUmma(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            consumer_mask,
            cta_group,
        )

    def consumer_release(self, state: PipelineState):
        """
        UMMA consumer marks the buffer of ``state.index`` as empty; the arrive
        is issued with this pipeline's consumer mask and ``cta_group``.
        """
        stage = state.index
        self.sync_object_empty.arrive(stage, self.consumer_mask,
                                      self.cta_group)
0 commit comments