@@ -62,16 +62,16 @@ module @test {
6262
6363constexpr char matmulAddStatic[] = R"mlir(
6464module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>>} {
65- func.func @entry(%arg0: memref<64x128xf32 >, %arg1: memref<128x128xf32 >, %arg2: memref<64x128xf32 >) {
66- %0 = bufferization.to_tensor %arg0 restrict : memref<64x128xf32 >
67- %1 = bufferization.to_tensor %arg1 restrict : memref<128x128xf32 >
68- %2 = tensor.empty() : tensor<64x128xf32 >
69- %cst = arith.constant 0.000000e+00 : f32
70- %3 = linalg.fill ins(%cst : f32 ) outs(%2 : tensor<64x128xf32 >) -> tensor<64x128xf32 >
71- %4 = linalg.matmul_transpose_b ins(%0, %1 : tensor<64x128xf32 >, tensor<128x128xf32 >) outs(%3 : tensor<64x128xf32 >) -> tensor<64x128xf32 >
72- %5 = tensor.empty() : tensor<64x128xf32 >
73- %6 = linalg.add ins(%4, %0 : tensor<64x128xf32 >, tensor<64x128xf32 >) outs(%5 : tensor<64x128xf32 >) -> tensor<64x128xf32 >
74- bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<64x128xf32 >, memref<64x128xf32 >) -> ()
65+ func.func @entry(%arg0: memref<64x128xf16 >, %arg1: memref<128x128xf16 >, %arg2: memref<64x128xf16 >) {
66+ %0 = bufferization.to_tensor %arg0 restrict : memref<64x128xf16 >
67+ %1 = bufferization.to_tensor %arg1 restrict : memref<128x128xf16 >
68+ %2 = tensor.empty() : tensor<64x128xf16 >
69+ %cst = arith.constant 0.000000e+00 : f16
70+ %3 = linalg.fill ins(%cst : f16 ) outs(%2 : tensor<64x128xf16 >) -> tensor<64x128xf16 >
71+ %4 = linalg.matmul_transpose_b ins(%0, %1 : tensor<64x128xf16 >, tensor<128x128xf16 >) outs(%3 : tensor<64x128xf16 >) -> tensor<64x128xf16 >
72+ %5 = tensor.empty() : tensor<64x128xf16 >
73+ %6 = linalg.add ins(%4, %0 : tensor<64x128xf16 >, tensor<64x128xf16 >) outs(%5 : tensor<64x128xf16 >) -> tensor<64x128xf16 >
74+ bufferization.materialize_in_destination %6 in restrict writable %arg2 : (tensor<64x128xf16 >, memref<64x128xf16 >) -> ()
7575 return
7676 }
7777}
@@ -141,13 +141,13 @@ template <unsigned N, unsigned M = N> struct TestAdd : TestBase {
141141template <unsigned N, unsigned M = N> struct TestMatmulAdd : TestBase {
142142 static constexpr unsigned size1 = N * M;
143143 static constexpr unsigned size2 = M * M;
144- float *buf0 = gcGetOrReport(runtime.usmNewDev<float >(size1));
145- float *buf1 = gcGetOrReport(runtime.usmNewDev<float >(size2));
146- float *buf2 = gcGetOrReport(runtime.usmNewShared<float >(size1));
144+ cl_half *buf0 = gcGetOrReport(runtime.usmNewDev<cl_half >(size1));
145+ cl_half *buf1 = gcGetOrReport(runtime.usmNewDev<cl_half >(size2));
146+ cl_half *buf2 = gcGetOrReport(runtime.usmNewShared<cl_half >(size1));
147147
148148 explicit TestMatmulAdd () {
149- float cpuBuf[size2];
150- std::fill (cpuBuf, cpuBuf + size2, 2 );
149+ cl_half cpuBuf[size2];
150+ std::fill (cpuBuf, cpuBuf + size2, 14336 );
151151 assert (runtime.usmCpy (ctx, cpuBuf, buf0, size1));
152152 assert (runtime.usmCpy (ctx, cpuBuf, buf1, size2));
153153 gcGetOrReport (ctx.finish ());
@@ -167,7 +167,7 @@ template <unsigned N, unsigned M = N> struct TestMatmulAdd : TestBase {
167167 gcGetOrReport (ctx.finish ());
168168 for (unsigned i = 0 ; i < size1; i++) {
169169 // std::cout << buf2[i] << " ";
170- assert (buf2[i] == 514 );
170+ assert (buf2[i] == 20496 );
171171 }
172172 // std::cout << "\n";
173173 }
0 commit comments