
Commit ab7a5b2

address comments
Signed-off-by: Max Dawkins <[email protected]>
1 parent a149504 commit ab7a5b2

2 files changed: +62, −2 lines changed

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Lines changed: 1 addition & 2 deletions

@@ -637,9 +637,8 @@ struct UnPackOpTiling
     // interchange to map sizes and offsets to the original input.
     int64_t outputRank = unPackOp.getDestRank();
     ReifiedRankedShapedTypeDims reifiedReturnShapes;
-    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes))) {
+    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
       return failure();
-    }
     SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
     SmallVector<OpFoldResult> origOffsets(destOffsets);
     SmallVector<OpFoldResult> origSizes(destSizes);
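
For intuition, the offset and size this path derives for a tiled unpack can be worked by hand. Below is a minimal C++ sketch, not the MLIR implementation (the helper name unpackedSlice is hypothetical); the formulas mirror the affine maps checked in the new test that follows:

#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical helper, for illustration only: maps a tile [off, off + sz)
// of the packed outer dimension to the corresponding slice of the unpacked
// destination, clipping the trailing tile when the destination size is not
// a multiple of the inner tile.
std::pair<int64_t, int64_t> unpackedSlice(int64_t off, int64_t sz,
                                          int64_t innerTile,
                                          int64_t destSize) {
  int64_t resultOffset = off * innerTile;                // d0 * 32
  int64_t resultSize =
      std::min(sz * innerTile, destSize - resultOffset); // min(1024, 2047 - d0 * 32)
  return {resultOffset, resultSize};
}

// unpackedSlice(32, 32, 32, 2047) == {1024, 1023}: the last tile is clipped
// to 1023 elements, matching the affine.min the test below checks.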

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Lines changed: 61 additions & 0 deletions

@@ -320,6 +320,67 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
+// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
   func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
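
The hunk does not include the file's RUN line; assuming this suite's usual transform-interpreter setup, the new case would be driven by something along these lines (the exact flags are an assumption, not shown in the diff):

// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s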
