Commit 0559d9a
[hopper][WS] Automatic data partition for multiple consumers (#6791)
When multiple consumers are requested to execute the same code region, the compiler determines how to divide the work between them. On Hopper, the compiler will, by default, attempt to split the input tensor A of a dot operation along the M dimension so that each consumer computes half of the output tensor independently. This approach is also known as cooperative partitioning. If this split is not advantageous (for instance, if it results in a smaller-than-native wgmma instruction), the compiler will instead attempt to split along the N dimension. For a typical GEMM kernel with a configured tile size of [128, 256, 64], the transformed code looks like the following (using source annotations instead of IR for illustration):

```
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)                                    # async_task 0, 1, 2
    num_pid_m = tl.cdiv(M, BLOCK_M)                                # async_task 0, 1, 2
    num_pid_n = tl.cdiv(N, BLOCK_N)                                # async_task 0, 1, 2
    pid_m = pid // num_pid_n                                       # async_task 0, 1, 2
    pid_n = pid % num_pid_n                                        # async_task 0, 1, 2
    offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)        # async_task 0, 1, 2
    offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)  # async_task 0, 1, 2
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)               # async_task 0, 1, 2
    offs_k = tl.arange(0, BLOCK_K)                                 # async_task 0
    a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)      # async_task 0
    acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)    # async_task 1
    acc_2 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)    # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_K)):                        # async_task 0, 1, 2
        a_1 = tl.load(a_ptrs_1)                                    # async_task 0
        a_2 = tl.load(a_ptrs_2)                                    # async_task 0
        b = tl.load(b_ptrs)                                        # async_task 0
        acc_1 += tl.dot(a_1, b)                                    # async_task 1
        acc_2 += tl.dot(a_2, b)                                    # async_task 2
        a_ptrs_1 += BLOCK_K * stride_ak                            # async_task 0
        a_ptrs_2 += BLOCK_K * stride_ak                            # async_task 0
        b_ptrs += BLOCK_K * stride_bk                              # async_task 0
    c_1 = acc_1.to(tl.float16)                                     # async_task 1
    c_2 = acc_2.to(tl.float16)                                     # async_task 2
    c_ptrs_1 = c_ptr + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :]  # async_task 1
    c_ptrs_2 = c_ptr + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :]  # async_task 2
    tl.store(c_ptrs_1, c_1)                                        # async_task 1
    tl.store(c_ptrs_2, c_2)                                        # async_task 2
```
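The fallback case mentioned above, splitting along N when the M split would produce a smaller-than-native wgmma instruction, is not shown in the commit message. Below is a minimal sketch of what that variant could look like, written by analogy with the example above; the split points and task annotations are illustrative, not compiler output. Here A is shared by both consumers while B and the output are split along N:

```
import triton
import triton.language as tl

@triton.jit
def matmul_ws_kernel_n_split(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)                                    # async_task 0, 1, 2
    num_pid_n = tl.cdiv(N, BLOCK_N)                                # async_task 0, 1, 2
    pid_m = pid // num_pid_n                                       # async_task 0, 1, 2
    pid_n = pid % num_pid_n                                        # async_task 0, 1, 2
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)               # async_task 0, 1, 2
    # The N range is halved; each consumer owns one half of the columns.
    offs_n_1 = pid_n * BLOCK_N + tl.arange(0, BLOCK_N // 2)        # async_task 0, 1, 2
    offs_n_2 = pid_n * BLOCK_N + tl.arange(BLOCK_N // 2, BLOCK_N)  # async_task 0, 1, 2
    offs_k = tl.arange(0, BLOCK_K)                                 # async_task 0
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)       # async_task 0
    b_ptrs_1 = b_ptr + (offs_k[:, None] * stride_bk + offs_n_1[None, :] * stride_bn)   # async_task 0
    b_ptrs_2 = b_ptr + (offs_k[:, None] * stride_bk + offs_n_2[None, :] * stride_bn)   # async_task 0
    acc_1 = tl.zeros((BLOCK_M, BLOCK_N // 2), dtype=tl.float32)    # async_task 1
    acc_2 = tl.zeros((BLOCK_M, BLOCK_N // 2), dtype=tl.float32)    # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_K)):                        # async_task 0, 1, 2
        a = tl.load(a_ptrs)                                        # async_task 0
        b_1 = tl.load(b_ptrs_1)                                    # async_task 0
        b_2 = tl.load(b_ptrs_2)                                    # async_task 0
        acc_1 += tl.dot(a, b_1)                                    # async_task 1
        acc_2 += tl.dot(a, b_2)                                    # async_task 2
        a_ptrs += BLOCK_K * stride_ak                              # async_task 0
        b_ptrs_1 += BLOCK_K * stride_bk                            # async_task 0
        b_ptrs_2 += BLOCK_K * stride_bk                            # async_task 0
    c_1 = acc_1.to(tl.float16)                                     # async_task 1
    c_2 = acc_2.to(tl.float16)                                     # async_task 2
    c_ptrs_1 = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n_1[None, :]  # async_task 1
    c_ptrs_2 = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n_2[None, :]  # async_task 2
    tl.store(c_ptrs_1, c_1)                                        # async_task 1
    tl.store(c_ptrs_2, c_2)                                        # async_task 2
```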
1 parent aeb4d4f commit 0559d9a

File tree

6 files changed: +1543 -1 lines changed
Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
// RUN: triton-opt %s -split-input-file --nvgpu-test-ws-data-partition=num-warp-groups=3 | FileCheck %s

// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 0}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
    %c1_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 1 : i32
    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
    %cst = arith.constant {async_task_id = array<i32: 1, 2>} dense<0.000000e+00> : tensor<128x256xf32, #mma>
    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
    %1 = tt.get_num_programs x {async_task_id = array<i32: 0, 1, 2>} : i32
    scf.for %arg6 = %0 to %arg3 step %1 : i32 {
      %2 = tt.splat %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked>
      %3 = tt.splat %arg1 {async_task_id = array<i32: 0>} : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked1>
      %4:2 = scf.for %arg7 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 {
        // CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr<f16>
        // CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr<f16>
        %8 = tt.load %2 {async_task_id = array<i32: 0>} : tensor<128x64x!tt.ptr<f16>, #blocked>
        // CHECK: %[[#LA1:]] = ttg.local_alloc %[[#GA1]]
        // CHECK: %[[#LA2:]] = ttg.local_alloc %[[#GA2]]
        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
        // CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr<f16>
        %10 = tt.load %3 {async_task_id = array<i32: 0>} : tensor<64x256x!tt.ptr<f16>, #blocked1>
        // CHECK: %[[#LB:]] = ttg.local_alloc %[[#GB]]
        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
        // CHECK: %[[#C1:]] = ttng.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma>
        // CHECK: %[[#C2:]] = ttng.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma>
        %12 = ttng.warp_group_dot %9, %11, %arg8 {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
        %13 = arith.addi %arg9, %c64_i32 {async_task_id = array<i32: 0>} : i32
        scf.yield {async_task_id = array<i32: 0, 1, 2>} %12, %13 : tensor<128x256xf32, #mma>, i32
      } {async_task_id = array<i32: 0, 1, 2>}
      %5 = arith.truncf %4#0 {async_task_id = array<i32: 1, 2>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
      %6 = ttg.convert_layout %5 {async_task_id = array<i32: 1, 2>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1>
      %7 = tt.splat %arg2 {async_task_id = array<i32: 1, 2>} : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #blocked1>
      // CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr<f16>, #blocked1>
      // CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr<f16>, #blocked1>
      tt.store %7, %6 {async_task_id = array<i32: 1, 2>} : tensor<128x256x!tt.ptr<f16>, #blocked1>
    }
    tt.return
  }
}

// -----

// CHECK-LABEL: @cross_dim_partition
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 0}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 0}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @cross_dim_partition(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg7: f32, %arg8: i32, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32) attributes {noinline = false} {
    %cst = arith.constant {async_task_id = array<i32: 1, 2>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
    %cst_0 = arith.constant {async_task_id = array<i32: 1, 2>} dense<true> : tensor<128x128xi1, #blocked>
    %c1_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 1 : i32
    %c128_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 128 : i32
    %c64_i32 = arith.constant {async_task_id = array<i32: 0>} 64 : i32
    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
    %1 = tt.get_program_id y {async_task_id = array<i32: 0, 1, 2>} : i32
    %2 = tt.load %arg1 {async_task_id = array<i32: 0, 1, 2>} : !tt.ptr<i32>
    %3 = arith.extsi %arg8 {async_task_id = array<i32: 0>} : i32 to i64
    tt.experimental_tensormap_create %arg6, %arg0, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg2, [%c64_i32, %c128_i32], [%arg8, %arg9], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg3, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg5, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    %4 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %5 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %6 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %7 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    %8 = tt.descriptor_load %4[%0, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<128x128xbf16
    %10 = tt.descriptor_load %5[%1, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}}
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}}
    %12 = ttng.warp_group_dot %9, %11, %cst {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x128xbf16, #shared, #smem> -> tensor<128x128xf32, #mma>
    %13 = arith.truncf %12 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma>
    %14 = ttg.local_alloc %13 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #mma>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    %15 = tt.descriptor_load %6[%0, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %16 = ttg.local_alloc %15 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    %17 = ttg.memdesc_trans %16 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x128xbf16, #shared, #smem> -> !ttg.memdesc<128x128xbf16, #shared1, #smem>
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<128x64xbf16, {{.*}} * !ttg.memdesc<64x128xbf16, {{.*}} -> tensor<128x128xf32, {{.*}}
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<128x64xbf16, {{.*}} * !ttg.memdesc<64x128xbf16, {{.*}} -> tensor<128x128xf32, {{.*}}
    %18 = ttng.warp_group_dot %17, %14, %cst {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x128xbf16, #shared1, #smem> * !ttg.memdesc<128x128xbf16, #shared, #smem> -> tensor<128x128xf32, #mma>
    %19 = ttg.convert_layout %18 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
    %20 = arith.truncf %19 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #blocked> to tensor<128x128xbf16, #blocked>
    %21 = tt.splat %arg4 {async_task_id = array<i32: 1, 2>} : !tt.ptr<bf16> -> tensor<1x128x!tt.ptr<bf16>, #blocked>
    %22 = tt.broadcast %21 {async_task_id = array<i32: 1, 2>} : tensor<1x128x!tt.ptr<bf16>, #blocked> -> tensor<128x128x!tt.ptr<bf16>, #blocked>
    %23 = tt.atomic_rmw fadd, relaxed, gpu, %22, %20, %cst_0 {async_task_id = array<i32: 1, 2>} : (tensor<128x128x!tt.ptr<bf16>, #blocked>, tensor<128x128xbf16, #blocked>, tensor<128x128xi1, #blocked>) -> tensor<128x128xbf16, #blocked>
    tt.return
  }
}

third_party/nvidia/hopper/include/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions

@@ -33,4 +33,17 @@ def NVGPUTestWSTaskPartition : Pass<"nvgpu-test-ws-task-partition", "mlir::ModuleOp"> {
   ];
 }
 
+def NVGPUTestWSDataPartition : Pass<"nvgpu-test-ws-data-partition", "mlir::ModuleOp"> {
+  let summary = "test warp specialization data partition";
+
+  let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let options = [
+    Option<"numWarpGroups", "num-warp-groups",
+           "int32_t", /*default*/"0",
+           "number of warp groups for warp specialization">
+  ];
+}
+
 #endif // NV_TRANSFORMS_PASSES

third_party/nvidia/hopper/lib/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,8 @@
 add_triton_library(NVHopperTransforms
   WarpSpecialization.cpp
-  WarpSpecialization/WSTaskPartition.cpp
   WarpSpecialization/Utility.cpp
+  WarpSpecialization/WSDataPartition.cpp
+  WarpSpecialization/WSTaskPartition.cpp
 
   DEPENDS
   NVHopperTransformsIncGen

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp

Lines changed: 5 additions & 0 deletions

@@ -11,6 +11,7 @@
 namespace mlir {
 
 void doTaskPartition(triton::FuncOp &funcOp, unsigned numWarpGroups);
+bool doDataPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups);
 
 #define GEN_PASS_DEF_NVGPUWARPSPECIALIZATION
 #include "nvidia/hopper/include/Transforms/Passes.h.inc"
@@ -27,6 +28,10 @@ class NVGPUWarpSpecializationPass
 
     // Partition key ops into multiple async tasks.
     doTaskPartition(funcOp, numWarpGroups);
+
+    // Partition ops into parallel sub ops.
+    if (!doDataPartition(funcOp, numWarpGroups - 1))
+      signalPassFailure();
   }
 
   void runOnOperation() override {
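For readers following the control flow: `doDataPartition` receives `numWarpGroups - 1` because one warp group acts as the producer (async task 0) and only the remaining consumer groups split the data. Below is a minimal Python sketch of the dimension-selection heuristic the commit message describes; the native-height constant and the divisibility check are assumptions for illustration (the real pass operates on TTGIR ops, not plain integers):

```
# Illustrative only: how the partition dimension might be chosen for a tile
# of shape [BLOCK_M, BLOCK_N]. The test above splits a 128x256 dot into two
# 64x256 dots, suggesting 64 as the assumed native wgmma height on Hopper.
WGMMA_NATIVE_M = 64  # assumption for illustration

def choose_partition_dim(block_m: int, block_n: int, num_consumers: int) -> str:
    """Return 'M' if splitting A along M keeps each sub-dot at full wgmma
    height; otherwise fall back to splitting along N."""
    if block_m % num_consumers == 0 and block_m // num_consumers >= WGMMA_NATIVE_M:
        return "M"  # cooperative partitioning: each consumer owns BLOCK_M // num_consumers rows
    return "N"      # fallback: each consumer owns BLOCK_N // num_consumers columns

# With the config from the commit message ([128, 256, 64] and 2 consumers):
assert choose_partition_dim(128, 256, 2) == "M"  # two 64x256 dots
# With a shorter tile, the M split would drop below the native wgmma height:
assert choose_partition_dim(64, 256, 2) == "N"   # split columns instead
```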

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.h

Lines changed: 43 additions & 0 deletions

@@ -1,9 +1,11 @@
 
 #ifndef NV_DIALECT_HOPPER_TRANSFORMS_UTILITY_H_
 
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
@@ -32,5 +34,46 @@ void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId);
 // Removes all async task ids from the given operation.
 void removeAsyncTaskIds(Operation *op);
 
+class OpBuilderWithAsyncTaskIds : public OpBuilder {
+public:
+  OpBuilderWithAsyncTaskIds(MLIRContext *context) : OpBuilder(context) {}
+
+  explicit OpBuilderWithAsyncTaskIds(Operation *op) : OpBuilder(op) {
+    setAsyncTaskIdsFromOp(op);
+  }
+
+  void setAsynTaskIdsFromArray(ArrayRef<AsyncTaskId> newAsyncTaskIds) {
+    asyncTaskIds = SmallVector<AsyncTaskId>(newAsyncTaskIds.begin(),
+                                            newAsyncTaskIds.end());
+  }
+
+  void setAsyncTaskIdsFromOp(Operation *op) {
+    setAsynTaskIdsFromArray(getAsyncTaskIds(op));
+  }
+
+  void setAsyncTaskIdsFromValueUsers(Value value) {
+    SetVector<AsyncTaskId> asyncTaskIdSet;
+    for (Operation *user : value.getUsers())
+      for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user))
+        asyncTaskIdSet.insert(asyncTaskId);
+    setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef());
+  }
+
+  template <typename OpTy, typename... Args>
+  OpTy createWithAsyncTaskIds(Args &&...args) {
+    OpTy op = OpBuilder::create<OpTy>(std::forward<Args>(args)...);
+    if (!asyncTaskIds.empty())
+      setAsyncTaskIds(op, asyncTaskIds);
+    return op;
+  }
+
+  template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
+    OpTy op = createWithAsyncTaskIds<OpTy>(std::forward<Args>(args)...);
+    return op;
+  }
+
+private:
+  SmallVector<AsyncTaskId> asyncTaskIds;
+};
 } // namespace mlir
 #endif // NV_DIALECT_HOPPER_TRANSFORMS_UTILITY_H_
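The new `OpBuilderWithAsyncTaskIds` wraps `OpBuilder` so that every operation it creates is automatically stamped with the current set of async task ids, which is how the partitioned sub-ops above keep consistent `async_task_id` annotations. A minimal Python sketch of the same pattern, with hypothetical `Op` and builder stand-ins (illustration only, not Triton's API):

```
class Op:
    """Hypothetical stand-in for an IR operation with attributes."""
    def __init__(self, name):
        self.name = name
        self.attrs = {}

class BuilderWithAsyncTaskIds:
    """Create ops and stamp each one with the current async task ids,
    mirroring OpBuilderWithAsyncTaskIds in Utility.h."""
    def __init__(self, task_ids=()):
        self.task_ids = list(task_ids)

    def set_task_ids_from_op(self, op):
        self.task_ids = list(op.attrs.get("async_task_id", []))

    def create(self, name):
        op = Op(name)
        if self.task_ids:  # only tag when task ids are set, as in the C++ version
            op.attrs["async_task_id"] = list(self.task_ids)
        return op

b = BuilderWithAsyncTaskIds(task_ids=[1, 2])
dot = b.create("ttng.warp_group_dot")
assert dot.attrs["async_task_id"] == [1, 2]
```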
