#include "gtest/gtest.h"

#include <algorithm>
#include <cassert>
#include <memory>

#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include <CL/cl_ext.h>
@@ -31,12 +28,12 @@ using namespace gc::gpu;
3128
// Static-shape element-wise add: arg2 = arg0 + arg1 over 64x64 f32 buffers.
constexpr char addStatic[] = R"mlir(
module @test {
  func.func @entry(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<64x64xf32>
    %1 = bufferization.to_tensor %arg1 restrict : memref<64x64xf32>
    %2 = tensor.empty() : tensor<64x64xf32>
    %3 = linalg.add ins(%1, %0 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2 : tensor<64x64xf32>) -> tensor<64x64xf32>
    bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<64x64xf32>, memref<64x64xf32>) -> ()
    return
  }
}
)mlir";
@@ -59,40 +56,69 @@ module @test {
5956}
6057)mlir" ;
6158
// Static-shape fused kernel: arg2 = matmul_transpose_b(arg0, arg1) + arg0.
// The DLTI attribute requests a 32-wide tile size from the lowering pipeline.
constexpr char matmulAddStatic[] = R"mlir(
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>>} {
  func.func @entry(%arg0: memref<64x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<64x128xf32>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<64x128xf32>
    %1 = bufferization.to_tensor %arg1 restrict : memref<128x128xf32>
    %2 = tensor.empty() : tensor<64x128xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64x128xf32>) -> tensor<64x128xf32>
    %4 = linalg.matmul_transpose_b ins(%0, %1 : tensor<64x128xf32>, tensor<128x128xf32>) outs(%3 : tensor<64x128xf32>) -> tensor<64x128xf32>
    %5 = tensor.empty() : tensor<64x128xf32>
    %6 = linalg.add ins(%4, %0 : tensor<64x128xf32>, tensor<64x128xf32>) outs(%5 : tensor<64x128xf32>) -> tensor<64x128xf32>
    bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<64x128xf32>, memref<64x128xf32>) -> ()
    return
  }
}
)mlir";
75+
76+ struct TestBase {
6377 OclRuntime runtime = gcGetOrReport(OclRuntime::get());
6478 cl_command_queue queue = gcGetOrReport(runtime.createQueue());
79+ OclContext ctx{runtime, queue};
80+ MLIRContext mlirCtx{gc::initCompilerAndGetDialects ()};
81+
82+ virtual void exec (std::shared_ptr<const OclModule> &mod) = 0;
83+
84+ virtual ~TestBase () { gcGetOrReport (runtime.releaseQueue (queue)); }
85+
86+ OwningOpRef<ModuleOp> parse (const char *code) {
87+ std::unique_ptr<llvm::MemoryBuffer> memBuf =
88+ llvm::MemoryBuffer::getMemBuffer (code);
89+ llvm::SourceMgr srcMgr;
90+ srcMgr.AddNewSourceBuffer (std::move (memBuf), SMLoc ());
91+ return parseSourceFile<ModuleOp>(srcMgr, &mlirCtx);
92+ }
93+ };
6594
95+ template <unsigned N, unsigned M = N> struct TestAdd : TestBase {
6696 static constexpr unsigned size = N * M;
6797 float *buf0 = gcGetOrReport(runtime.usmNewDev<float >(size));
6898 float *buf1 = gcGetOrReport(runtime.usmNewDev<float >(size));
6999 float *buf2 = gcGetOrReport(runtime.usmNewShared<float >(size));
70- MLIRContext mlirCtx{gc::initCompilerAndGetDialects ()};
71- float cpuBuf1[size] = {};
72- float cpuBuf2[size] = {};
73100
74- explicit TestAdd () { std::fill (cpuBuf1, cpuBuf1 + size, 2 .0f ); }
101+ explicit TestAdd () {
102+ float cpuBuf[size];
103+ std::fill (cpuBuf, cpuBuf + size, 2 .0f );
104+ assert (runtime.usmCpy (ctx, cpuBuf, buf0, size));
105+ assert (runtime.usmCpy (ctx, cpuBuf, buf1, size));
106+ gcGetOrReport (ctx.finish ());
107+ }
75108
76- virtual ~TestAdd () {
77- gcGetOrReport (runtime.releaseQueue (queue));
109+ ~TestAdd () override {
78110 assert (runtime.usmFree (buf0));
79111 assert (runtime.usmFree (buf1));
80112 assert (runtime.usmFree (buf2));
81113 }
82114
83- virtual void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx) = 0;
84-
85115 void test (const char *code) {
86- OclContext ctx (runtime, queue);
87- assert (runtime.usmCpy (ctx, cpuBuf1, buf0, size));
88- assert (runtime.usmCpy (ctx, cpuBuf1, buf1, size));
89-
90116 OclModuleBuilder builder (parse (code));
91117 auto mod = gcGetOrReport (builder.build (runtime));
118+ exec (mod);
92119
93- exec (mod, ctx);
94-
95- assert (runtime.usmCpy (ctx, buf2, cpuBuf2, size));
120+ float cpuBuf[size];
121+ assert (runtime.usmCpy (ctx, buf2, cpuBuf, size));
96122 gcGetOrReport (ctx.finish ());
97123
98124 for (unsigned i = 0 ; i < size; i++) {
@@ -101,24 +127,51 @@ template <unsigned N, unsigned M = N> struct TestAdd {
101127 }
102128 // std::cout << "\n";
103129
104- for (float i : cpuBuf2 ) {
105- // std::cout << cpuBuf2[i] << " ";
130+ for (float i : cpuBuf ) {
131+ // std::cout << i << " ";
106132 assert (i == 4 .0f );
107133 }
108134 }
135+ };
109136
110- OwningOpRef<ModuleOp> parse (const char *code) {
111- std::unique_ptr<llvm::MemoryBuffer> memBuf =
112- llvm::MemoryBuffer::getMemBuffer (code);
113- llvm::SourceMgr srcMgr;
114- srcMgr.AddNewSourceBuffer (std::move (memBuf), SMLoc ());
115- return parseSourceFile<ModuleOp>(srcMgr, &mlirCtx);
137+ template <unsigned N, unsigned M = N> struct TestMatmulAdd : TestBase {
138+ static constexpr unsigned size1 = N * M;
139+ static constexpr unsigned size2 = M * M;
140+ float *buf0 = gcGetOrReport(runtime.usmNewDev<float >(size1));
141+ float *buf1 = gcGetOrReport(runtime.usmNewDev<float >(size2));
142+ float *buf2 = gcGetOrReport(runtime.usmNewShared<float >(size1));
143+
144+ explicit TestMatmulAdd () {
145+ float cpuBuf[size2];
146+ std::fill (cpuBuf, cpuBuf + size2, 2 );
147+ assert (runtime.usmCpy (ctx, cpuBuf, buf0, size1));
148+ assert (runtime.usmCpy (ctx, cpuBuf, buf1, size2));
149+ gcGetOrReport (ctx.finish ());
150+ }
151+
152+ ~TestMatmulAdd () override {
153+ assert (runtime.usmFree (buf0));
154+ assert (runtime.usmFree (buf1));
155+ assert (runtime.usmFree (buf2));
156+ }
157+
158+ void test (const char *code) {
159+ OclModuleBuilder builder (parse (code));
160+ auto mod = gcGetOrReport (builder.build (runtime));
161+ exec (mod);
162+
163+ gcGetOrReport (ctx.finish ());
164+ for (unsigned i = 0 ; i < size1; i++) {
165+ // std::cout << buf2[i] << " ";
166+ assert (buf2[i] == 514 );
167+ }
168+ // std::cout << "\n";
116169 }
117170};
118171
119172TEST (GpuOclRuntime, TestAddStatic) {
120- struct TestAddStatic1 : TestAdd<32 > {
121- void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx ) override {
173+ struct TestAddStatic1 : TestAdd<64 > {
174+ void exec (std::shared_ptr<const OclModule> &mod) override {
122175 assert (mod->isStatic );
123176 StaticExecutor<3 > exec (mod);
124177 exec (ctx, buf0, buf1, buf2);
@@ -128,8 +181,8 @@ TEST(GpuOclRuntime, TestAddStatic) {
128181 } test1;
129182 test1.test (addStatic);
130183
131- struct TestAddStatic2 : TestAdd<32 > {
132- void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx ) override {
184+ struct TestAddStatic2 : TestAdd<64 > {
185+ void exec (std::shared_ptr<const OclModule> &mod) override {
133186 assert (mod->isStatic );
134187 StaticExecutor<3 > exec (mod);
135188 exec.arg (buf0);
@@ -146,7 +199,7 @@ TEST(GpuOclRuntime, TestAddStatic) {
146199TEST (GpuOclRuntime, TestAddDynamic) {
147200 GTEST_SKIP () << " Dynamic shapes are not yet supported" ;
148201 struct TestAddDynamic : TestAdd<32 , 64 > {
149- void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx ) override {
202+ void exec (std::shared_ptr<const OclModule> &mod) override {
150203 assert (!mod->isStatic );
151204 int64_t shape[] = {32 , 64 };
152205 int64_t strides[] = {64 , 1 };
@@ -161,3 +214,15 @@ TEST(GpuOclRuntime, TestAddDynamic) {
161214 } test;
162215 test.test (addDynamic);
163216}
217+
218+ TEST (GpuOclRuntime, TestMatmulAddStatic) {
219+ struct Test : TestMatmulAdd<64 , 128 > {
220+ void exec (std::shared_ptr<const OclModule> &mod) override {
221+ assert (mod->isStatic );
222+ StaticExecutor<3 > exec (mod);
223+ exec (ctx, buf0, buf1, buf2);
224+ assert (exec.isSmall ());
225+ }
226+ } test;
227+ test.test (matmulAddStatic);
228+ }