Commit 4df01dc

[mlir][sparse][gpu][nvidia] add pruning step and check to 2:4 matrix multiplication
(1) Without the check, the results may silently be wrong, so the check is needed.
(2) Add a pruning step to guarantee the 2:4 property.

Note: in the longer run, we may want to split out the pruning step somehow, or make it optional.

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D155320
Parent: 739164c
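For context on what the commit enforces: the 2:4 ("two out of four") structured-sparsity property required by the sparse tensor cores means that every strip of four consecutive values along a row holds at most two nonzeros. A minimal sketch of such a check (illustration only, not part of the commit; the helper name is hypothetical):

#include <cstddef>

// Illustration only: does `row` (n contiguous values, n divisible by 4)
// satisfy the 2:4 property, i.e. at most 2 nonzeros per 1x4 strip?
static bool isTwoOutOfFour(const float *row, size_t n) {
  for (size_t i = 0; i + 4 <= n; i += 4) {
    int nonzeros = 0;
    for (size_t j = 0; j < 4; ++j)
      if (row[i + j] != 0.0f)
        ++nonzeros;
    if (nonzeros > 2) // more than 2 nonzeros in a 1x4 strip
      return false;
  }
  return true;
}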

File tree: 2 files changed, +152 −1 lines

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 20 additions & 1 deletion
@@ -567,7 +567,7 @@ mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
 // and returning workspace and compressed matrices data buffer sizes.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
-                             void *c, int32_t ctp, CUstream /*stream*/) {
+                             void *c, int32_t ctp, CUstream stream) {
   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
   // TODO: support more advanced settings, e.g., the input right operand is a
   // sparse matrix assuming matA is the sparse matrix
@@ -596,6 +596,25 @@ mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
       &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))
 
+  // Pruning step (in-place).
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseLtSpMMAPrune(&cusparseLt_env, &(matA->matmul), matA->values,
+                           matA->values, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
+
+  // Check structure of A.
+  // Note that this adds a synchronization on the stream.
+  // TODO: Do we want that?
+  int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
+      &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
+  int valid = 0;
+  mgpuMemcpy(&valid, dvalid, sizeof(int), stream);
+  mgpuStreamSynchronize(stream);
+  mgpuMemFree(dvalid, stream);
+  if (valid != 0)
+    fprintf(stderr, "CUPARSE-LT: sparse matrix is not 2:4; computed results "
+                    "will be invalid\n");
+
   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace(
       &cusparseLt_env, &(matA->plan), &workspace_size_))
   CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
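For intuition about the pruning mode used above: CUSPARSELT_PRUNE_SPMMA_STRIP prunes each 1x4 strip down to its two largest-magnitude entries, zeroing the other two. A rough reference sketch of that behavior (an assumption about the semantics in simplified row-major form; hypothetical helper, not the cuSPARSELt implementation):

#include <cmath>
#include <cstddef>
#include <utility>

// Illustration only: keep the 2 largest-magnitude entries of every
// 1x4 strip of a row, zeroing the other 2 (in place).
static void pruneStrip24(float *row, size_t n) {
  for (size_t i = 0; i + 4 <= n; i += 4) {
    size_t keep0 = i, keep1 = i + 1; // indices of the top-2 magnitudes
    if (std::fabs(row[keep1]) > std::fabs(row[keep0]))
      std::swap(keep0, keep1);
    for (size_t j = i + 2; j < i + 4; ++j) {
      if (std::fabs(row[j]) > std::fabs(row[keep0])) {
        keep1 = keep0;
        keep0 = j;
      } else if (std::fabs(row[j]) > std::fabs(row[keep1])) {
        keep1 = j;
      }
    }
    for (size_t j = i; j < i + 4; ++j)
      if (j != keep0 && j != keep1)
        row[j] = 0.0f; // pruned away
  }
}

On the test matrix below, each strip (4, 1, 4, 1) keeps the two 4s and zeroes the two 1s, which is exactly the (4, 0, 4, 0) pattern the FileCheck lines expect.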
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
//
// NOTE: this test requires gpu-sm80 and cusparselt
//
// RUN: mlir-opt --sparse-compiler="enable-runtime-library=false enable-gpu-libgen=true gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
// RUN: %s \
// RUN: | mlir-cpu-runner \
// RUN:   --shared-libs=%mlir_cuda_runtime \
// RUN:   --shared-libs=%mlir_c_runner_utils \
// RUN:   --e main --entry-point-result=void \
// RUN: | FileCheck %s

#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

module {

  llvm.func @mgpuCreateSparseLtEnv()
  llvm.func @mgpuDestroySparseLtEnv()

  //
  // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
  //
  func.func @matmul(%arg0: tensor<16x16xf16>,
                    %arg1: tensor<16x16xf16>,
                    %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
    %0 = linalg.generic {
      DENSE24,
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"]
    }
    ins(%arg0, %arg1 : tensor<16x16xf16>, tensor<16x16xf16>)
    outs(%arg2 : tensor<16x16xf16>) {
    ^bb0(%in: f16, %in_0: f16, %out: f16):
      %1 = arith.mulf %in, %in_0 : f16
      %2 = arith.addf %out, %1 : f16
      linalg.yield %2 : f16
    } -> tensor<16x16xf16>
    return %0 : tensor<16x16xf16>
  }

  func.func @main() {
    llvm.call @mgpuCreateSparseLtEnv() : () -> ()

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index

    %f0 = arith.constant 0.0 : f16
    %f1 = arith.constant 1.0 : f16
    %f4 = arith.constant 4.0 : f16

    // Initial A, B, C matrices.
    %A = tensor.generate {
    ^bb0(%i: index, %j: index):
      %val = arith.andi %j, %c1 : index
      %cmp = arith.cmpi eq, %val, %c0 : index
      %res = arith.select %cmp, %f4, %f1 : f16
      tensor.yield %res : f16
    } : tensor<16x16xf16>
    %B = tensor.generate {
    ^bb0(%i: index, %j: index):
      %cmp = arith.cmpi eq, %i, %j : index
      %res = arith.select %cmp, %f1, %f0 : f16
      tensor.yield %res : f16
    } : tensor<16x16xf16>
    %C = tensor.generate {
    ^bb0(%i: index, %j: index):
      tensor.yield %f0 : f16
    } : tensor<16x16xf16>

    // Call the kernel.
    //
    // By effectively computing D = A B + C with id(B) and zero(C)
    // the resulting matrix returns the pruned A back to the caller.
    //
    %D = call @matmul(%A, %B, %C): (tensor<16x16xf16>, tensor<16x16xf16>, tensor<16x16xf16>) -> (tensor<16x16xf16>)

    //
    // This was the original matrix.
    //
    // CHECK:      ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    // CHECK-NEXT: ( 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1 )
    //
    scf.for %i = %c0 to %c16 step %c1 {
      %va = vector.transfer_read %A[%i, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
      vector.print %va : vector<16xf16>
    }

    //
    // This is the STRIP-pruned matrix.
    //
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    // CHECK-NEXT: ( 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0 )
    //
    scf.for %i = %c0 to %c16 step %c1 {
      %vd = vector.transfer_read %D[%i, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
      vector.print %vd : vector<16xf16>
    }

    llvm.call @mgpuDestroySparseLtEnv() : () -> ()
    return
  }
}
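A short worked reading of the test: since %B is the 16x16 identity and %C is all zeros, the kernel effectively returns D = prune(A) * I + 0 = prune(A), so printing %D exposes the pruning performed inside the runtime wrapper. Every row of %A alternates (4, 1, 4, 1, ...), which violates 2:4 (four nonzeros per strip); strip pruning keeps the 4s and drops the 1s, producing the (4, 0, 4, 0, ...) rows in the second block of CHECK lines.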
