@@ -877,3 +877,67 @@ hal.executable public @main {
877877// CHECK: vector.insert_strided_slice %[[C_70_4]], {{.*}}offsets = [7, 0, 0, 0, 0, 0]{{.*}} : vector<4xf32> into vector<8x1x2x1x1x4xf32>
878878// CHECK: vector.insert_strided_slice %[[C_71_4]], {{.*}}offsets = [7, 0, 1, 0, 0, 0]{{.*}} : vector<4xf32> into vector<8x1x2x1x1x4xf32>
879879// CHECK: vector.transfer_write
880+
881+ // -----
882+
883+ #layout = #hal.pipeline.layout<bindings = [
884+   #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
885+   #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
886+   #hal.pipeline.binding<storage_buffer, Indirect>
887+ ], flags = Indirect>
888+
889+ #lowering_config = #iree_gpu.lowering_config<{
890+   promote_operands = [0, 1],
891+   reduction = [0, 0, 4],
892+   thread = [1, 4, 0],
893+   workgroup = [1, 128, 0]
894+ }>
895+ #translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
896+
897+ hal.executable public @main {
898+   hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb">) {
899+     hal.executable.export public @small_m_matmul ordinal(0) layout(#layout) {
900+     ^bb0(%arg0: !hal.device):
901+       %x, %y, %z = flow.dispatch.workgroup_count_from_slice
902+       hal.return %x, %y, %z : index, index, index
903+     }
904+     builtin.module {
905+       func.func @small_m_matmul() attributes {translation_info = #translation_info} {
906+         %cst = arith.constant 0.000000e+00 : f32
907+         %c0 = arith.constant 0 : index
908+         %0 = hal.interface.binding.subspan layout(#layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1000xf32>>
909+         %1 = hal.interface.binding.subspan layout(#layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1000x512xf32>>
910+         %2 = hal.interface.binding.subspan layout(#layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<4x512xf32>>
911+         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4, 1000], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1000xf32>> -> tensor<4x1000xf32>
912+         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1000, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1000x512xf32>> -> tensor<1000x512xf32>
913+         %5 = tensor.empty() : tensor<4x512xf32>
914+         %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<4x512xf32>) -> tensor<4x512xf32>
915+         %7 = linalg.matmul {lowering_config = #lowering_config}
916+             ins(%3, %4 : tensor<4x1000xf32>, tensor<1000x512xf32>)
917+             outs(%6 : tensor<4x512xf32>) -> tensor<4x512xf32>
918+         flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [4, 512], strides = [1, 1] : tensor<4x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x512xf32>>
919+         return
920+       }
921+     }
922+   }
923+ }
924+
925+ // CHECK-LABEL: func @small_m_matmul
926+ // CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
927+ // CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
928+ // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
929+ // CHECK-DAG: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<1x6xf32, #gpu.address_space<workgroup>>
930+ // CHECK-DAG: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<4x130xf32, #gpu.address_space<workgroup>>
931+ // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1000 step %c4 {{.*}} -> (vector<1x4xf32>)
932+ // CHECK: gpu.barrier
933+
934+ // TODO: The fact that this read gets hoisted out of the subsequent for loop
935+ // is a bug in LICM that does no verification that the loop has at least one
936+ // trip.
937+ // CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<4xf32>
938+ // CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c32
939+ // CHECK-NEXT: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC]]
940+ // CHECK: gpu.barrier
941+ // CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<4xf32>
942+ // CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf32>
943+ // CHECK: vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]]