1919#include " gtest/gtest.h"
2020#include < memory>
2121
22- #include < mlir/Dialect/GPU/Transforms/Passes.h>
23-
24- #include " gc/Transforms/Passes.h"
2522#include " mlir/Target/LLVMIR/Export.h"
2623#include " mlir/Target/LLVMIR/ModuleTranslation.h"
2724#include < CL/cl_ext.h>
@@ -59,40 +56,69 @@ module @test {
5956}
6057)mlir" ;
6158
62- template <unsigned N, unsigned M = N> struct TestAdd {
// MLIR source for the static-shape matmul+add test kernel.
// @entry takes three f16 memrefs: %arg0 (64x128), %arg1 (128x128) and
// %arg2 (64x128). It zero-fills an accumulator, computes
// linalg.matmul_transpose_b(%arg0, %arg1) into it, adds %arg0 to the product
// with linalg.add, and materializes the sum into %arg2.
// The module attribute carries a "tile_size" = 32 entry in its DLTI spec.
59+ constexpr char matmulAddStatic[] = R"mlir(
60+ module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>>} {
61+ func.func @entry(%arg0: memref<64x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<64x128xf16>) {
62+ %0 = bufferization.to_tensor %arg0 restrict : memref<64x128xf16>
63+ %1 = bufferization.to_tensor %arg1 restrict : memref<128x128xf16>
64+ %2 = tensor.empty() : tensor<64x128xf16>
65+ %cst = arith.constant 0.000000e+00 : f16
66+ %3 = linalg.fill ins(%cst : f16) outs(%2 : tensor<64x128xf16>) -> tensor<64x128xf16>
67+ %4 = linalg.matmul_transpose_b ins(%0, %1 : tensor<64x128xf16>, tensor<128x128xf16>) outs(%3 : tensor<64x128xf16>) -> tensor<64x128xf16>
68+ %5 = tensor.empty() : tensor<64x128xf16>
69+ %6 = linalg.add ins(%4, %0 : tensor<64x128xf16>, tensor<64x128xf16>) outs(%5 : tensor<64x128xf16>) -> tensor<64x128xf16>
70+ bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<64x128xf16>, memref<64x128xf16>) -> ()
71+ return
72+ }
73+ }
74+ )mlir" ;
75+
76+ struct TestBase {
6377 OclRuntime runtime = gcGetOrReport(OclRuntime::get());
6478 cl_command_queue queue = gcGetOrReport(runtime.createQueue());
79+ OclContext ctx{runtime, queue};
80+ MLIRContext mlirCtx{gc::initCompilerAndGetDialects ()};
81+
82+ virtual void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx) = 0;
83+
84+ virtual ~TestBase () { gcGetOrReport (runtime.releaseQueue (queue)); }
85+
86+ OwningOpRef<ModuleOp> parse (const char *code) {
87+ std::unique_ptr<llvm::MemoryBuffer> memBuf =
88+ llvm::MemoryBuffer::getMemBuffer (code);
89+ llvm::SourceMgr srcMgr;
90+ srcMgr.AddNewSourceBuffer (std::move (memBuf), SMLoc ());
91+ return parseSourceFile<ModuleOp>(srcMgr, &mlirCtx);
92+ }
93+ };
6594
95+ template <unsigned N, unsigned M = N> struct TestAdd : TestBase {
6696 static constexpr unsigned size = N * M;
6797 float *buf0 = gcGetOrReport(runtime.usmNewDev<float >(size));
6898 float *buf1 = gcGetOrReport(runtime.usmNewDev<float >(size));
6999 float *buf2 = gcGetOrReport(runtime.usmNewShared<float >(size));
70- MLIRContext mlirCtx{gc::initCompilerAndGetDialects ()};
71- float cpuBuf1[size] = {};
72- float cpuBuf2[size] = {};
73100
74- explicit TestAdd () { std::fill (cpuBuf1, cpuBuf1 + size, 2 .0f ); }
101+ explicit TestAdd () {
102+ float cpuBuf[size];
103+ std::fill (cpuBuf, cpuBuf + size, 2 .0f );
104+ assert (runtime.usmCpy (ctx, cpuBuf, buf0, size));
105+ assert (runtime.usmCpy (ctx, cpuBuf, buf1, size));
106+ }
75107
76- virtual ~TestAdd () {
77- gcGetOrReport (runtime.releaseQueue (queue));
108+ ~TestAdd () override {
78109 assert (runtime.usmFree (buf0));
79110 assert (runtime.usmFree (buf1));
80111 assert (runtime.usmFree (buf2));
112+ TestBase::~TestBase ();
81113 }
82114
83- virtual void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx) = 0;
84-
85115 void test (const char *code) {
86- OclContext ctx (runtime, queue);
87- assert (runtime.usmCpy (ctx, cpuBuf1, buf0, size));
88- assert (runtime.usmCpy (ctx, cpuBuf1, buf1, size));
89-
90116 OclModuleBuilder builder (parse (code));
91117 auto mod = gcGetOrReport (builder.build (runtime));
92-
93118 exec (mod, ctx);
94119
95- assert (runtime.usmCpy (ctx, buf2, cpuBuf2, size));
120+ float cpuBuf[size];
121+ assert (runtime.usmCpy (ctx, buf2, cpuBuf, size));
96122 gcGetOrReport (ctx.finish ());
97123
98124 for (unsigned i = 0 ; i < size; i++) {
@@ -101,18 +127,46 @@ template <unsigned N, unsigned M = N> struct TestAdd {
101127 }
102128 // std::cout << "\n";
103129
104- for (float i : cpuBuf2 ) {
130+ for (float i : cpuBuf ) {
105131 // std::cout << cpuBuf2[i] << " ";
106132 assert (i == 4 .0f );
107133 }
108134 }
135+ };
109136
110- OwningOpRef<ModuleOp> parse (const char *code) {
111- std::unique_ptr<llvm::MemoryBuffer> memBuf =
112- llvm::MemoryBuffer::getMemBuffer (code);
113- llvm::SourceMgr srcMgr;
114- srcMgr.AddNewSourceBuffer (std::move (memBuf), SMLoc ());
115- return parseSourceFile<ModuleOp>(srcMgr, &mlirCtx);
137+ template <unsigned N, unsigned M = N> struct TestMatmulAdd : TestBase {
138+ static constexpr unsigned size1 = N * M;
139+ static constexpr unsigned size2 = M * M;
140+ cl_half *buf0 = gcGetOrReport(runtime.usmNewDev<cl_half>(size1));
141+ cl_half *buf1 = gcGetOrReport(runtime.usmNewDev<cl_half>(size2));
142+ cl_half *buf2 = gcGetOrReport(runtime.usmNewShared<cl_half>(size1));
143+
144+ explicit TestMatmulAdd () {
145+ cl_half cpuBuf[size2];
146+ std::fill (cpuBuf, cpuBuf + size2, 2 .0f );
147+ assert (runtime.usmCpy (ctx, cpuBuf, buf0, size1));
148+ assert (runtime.usmCpy (ctx, cpuBuf, buf1, size2));
149+ }
150+
151+ ~TestMatmulAdd () override {
152+ assert (runtime.usmFree (buf0));
153+ assert (runtime.usmFree (buf1));
154+ assert (runtime.usmFree (buf2));
155+ TestBase::~TestBase ();
156+ }
157+
158+ void test (const char *code) {
159+ OclModuleBuilder builder (parse (code));
160+ auto mod = gcGetOrReport (builder.build (runtime));
161+ exec (mod, ctx);
162+
163+ gcGetOrReport (ctx.finish ());
164+
165+ for (unsigned i = 0 ; i < size2; i++) {
166+ // std::cout << buf2[i] << " ";
167+ assert (buf2[i] == 4 .0f );
168+ }
169+ // std::cout << "\n";
116170 }
117171};
118172
@@ -143,6 +197,19 @@ TEST(GpuOclRuntime, TestAddStatic) {
143197 test2.test (addStatic);
144198}
145199
200+ TEST (GpuOclRuntime, TestMatmulAddStatic) {
201+ GTEST_SKIP () << " Temporary disabled until #344 is implemented" ;
202+ struct Test : TestMatmulAdd<64 , 128 > {
203+ void exec (std::shared_ptr<const OclModule> &mod, OclContext &ctx) override {
204+ assert (mod->isStatic );
205+ StaticExecutor<3 > exec (mod);
206+ exec (ctx, buf0, buf1, buf2);
207+ assert (exec.isSmall ());
208+ }
209+ } test;
210+ test.test (matmulAddStatic);
211+ }
212+
146213TEST (GpuOclRuntime, TestAddDynamic) {
147214 GTEST_SKIP () << " Dynamic shapes are not yet supported" ;
148215 struct TestAddDynamic : TestAdd<32 , 64 > {
0 commit comments