122 changes: 96 additions & 26 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -22704,36 +22704,107 @@ struct ScatterMultiplySimplify final
}
}

auto status =
detectConstantSetindexScatterOp(scatterOp, /*allowedMultipleUses*/
false,
/*onlyConstantZerosAllowed*/
true, nullptr);
if (!status.ok())
if (!scatterOp.getUniqueIndices()) {
return failure();
}

bool isAllZeros = false, isAllOnes = false;

SplatElementsAttr constSetIndexValue = nullptr;
auto status = detectConstantSetindexScatterOp(
scatterOp, /*allowedMultipleUses*/ false,
[&isAllZeros, &isAllOnes](mlir::Value input) {
isAllZeros = matchPattern(input, m_AnyZeroFloat()) ||
matchPattern(input, m_Zero());
if (!isAllZeros) {
isAllOnes = matchPattern(input, m_OneFloat()) ||
matchPattern(input, m_One());
}
return isAllZeros || isAllOnes;
},
constSetIndexValue);
if (!status.ok()) {
return rewriter.notifyMatchFailure(op, status.message());
}

SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);

auto gatheredValues = stablehlo::GatherOp::create(
rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
getGatherDims(rewriter.getContext(),
scatterOp.getScatterDimensionNumbers()),
rewriter.getDenseI64ArrayAttr(sliceSizes),
scatterOp.getIndicesAreSortedAttr());
if (isAllZeros) { // non-scattered values become zeros
auto gatheredValues = stablehlo::GatherOp::create(
rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
getGatherDims(rewriter.getContext(),
scatterOp.getScatterDimensionNumbers()),
rewriter.getDenseI64ArrayAttr(sliceSizes),
scatterOp.getIndicesAreSortedAttr());

Value mulRhs;
if (constSetIndexValue) {
mulRhs = stablehlo::ConstantOp::create(
rewriter, op.getLoc(),
constSetIndexValue.resizeSplat(
cast<ShapedType>(gatheredValues.getType())));
} else {
mulRhs = scatterOp.getUpdates()[0];
}
auto newUpdates = stablehlo::MulOp::create(rewriter, op.getLoc(),
                                           gatheredValues, mulRhs);

auto newUpdates = stablehlo::MulOp::create(
rewriter, op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]);
auto newScatterOp = stablehlo::ScatterOp::create(
rewriter, op.getLoc(), scatterOp.getResultTypes(),
scatterOp.getInputs(), scatterOp.getScatterIndices(),
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
scatterOp.getIndicesAreSortedAttr(),
scatterOp.getUniqueIndicesAttr());

auto newScatterOp = stablehlo::ScatterOp::create(
rewriter, op.getLoc(), scatterOp.getResultTypes(),
scatterOp.getInputs(), scatterOp.getScatterIndices(),
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
newScatterOp.getUpdateComputation().takeBody(
scatterOp.getUpdateComputation());
rewriter.replaceOp(op, newScatterOp);
auto &updateRegion = newScatterOp.getUpdateComputation();
auto *block = rewriter.createBlock(&updateRegion);
auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
.getElementType();
auto argType = RankedTensorType::get({}, elemType);
block->addArgument(argType, op.getLoc());
block->addArgument(argType, op.getLoc());
rewriter.setInsertionPointToStart(block);
stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));

return success();
rewriter.replaceOp(op, newScatterOp);
return success();
}

if (isAllOnes) { // non-scattered values stay as is
auto newScatterOp = stablehlo::ScatterOp::create(
rewriter, op.getLoc(), scatterOp.getResultTypes(),
ValueRange(otherValue), scatterOp.getScatterIndices(),
scatterOp.getUpdates(), scatterOp.getScatterDimensionNumbersAttr(),
scatterOp.getIndicesAreSortedAttr(),
scatterOp.getUniqueIndicesAttr());

auto &updateRegion = newScatterOp.getUpdateComputation();
auto *block = rewriter.createBlock(&updateRegion);
auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
.getElementType();
auto argType = RankedTensorType::get({}, elemType);
block->addArgument(argType, op.getLoc());
block->addArgument(argType, op.getLoc());
rewriter.setInsertionPointToStart(block);

Value mulRhs;
if (constSetIndexValue) {
mulRhs = stablehlo::ConstantOp::create(
rewriter, op.getLoc(),
constSetIndexValue.resizeSplat(
RankedTensorType::get({}, elemType)));
} else {
mulRhs = block->getArgument(1);
}
auto mulOp = stablehlo::MulOp::create(rewriter, op.getLoc(),
block->getArgument(0), mulRhs);
stablehlo::ReturnOp::create(rewriter, op.getLoc(), mulOp.getResult());

rewriter.replaceOp(op, newScatterOp);
return success();
}

return failure();
}
};
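
The rewrite now covers two splat inputs for a multiplied setindex scatter, and it bails out unless unique_indices is set: with duplicate indices, the region-based multiply below would compound across writes instead of keeping a single winning value. For a splat-zeros input, mul(scatter(0, I, u), x) becomes scatter(0, I, u * gather(x, I)); for a splat-ones input, the scatter writes directly into x and the multiply moves into the update region. A minimal sketch of the all-ones case, with hypothetical SSA names and most attributes elided (the checked form is in test/lit_tests/mulscatterones.mlir below):

// Before: setindex scatter into an all-ones init, then multiply by %x.
%s = "stablehlo.scatter"(%ones, %idx, %upd) <{..., unique_indices = true}> ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  stablehlo.return %b : tensor<f32>
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
%r = stablehlo.multiply %s, %x : tensor<1024x1024xf32>

// After: scatter into %x directly; the multiply happens in the region.
%r2 = "stablehlo.scatter"(%x, %idx, %upd) <{..., unique_indices = true}> ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %m = stablehlo.multiply %a, %b : tensor<f32>
  stablehlo.return %m : tensor<f32>
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>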

@@ -22807,10 +22878,9 @@ struct UnaryElementwiseScatterSimplify final
if (!scatterOp)
return rewriter.notifyMatchFailure(op, "not a scatter op");

DenseElementsAttr scatterInputAttr;
auto status = detectConstantSetindexScatterOp(
scatterOp, false, /*onlyConstantZerosAllowed*/ false,
&scatterInputAttr);
scatterOp, false,
[](mlir::Value input) { return matchPattern(input, m_Constant()); });
if (!status.ok()) {
return rewriter.notifyMatchFailure(op, status.message());
}
35 changes: 22 additions & 13 deletions src/enzyme_ad/jax/Passes/StructuredTensors.cpp
@@ -12,8 +12,15 @@ namespace enzyme {

absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
bool allowedMultipleUses,
bool onlyConstantZerosAllowed,
DenseElementsAttr *constAttr) {
InputValidatorFn inputValidator) {
SplatElementsAttr constSetIndexValue = nullptr;
return detectConstantSetindexScatterOp(scatterOp, allowedMultipleUses,
inputValidator, constSetIndexValue);
}

absl::Status detectConstantSetindexScatterOp(
stablehlo::ScatterOp scatterOp, bool allowedMultipleUses,
InputValidatorFn inputValidator, SplatElementsAttr &constSetIndexValue) {
if (scatterOp.getInputs().size() != 1) {
return absl::UnimplementedError(
"Detection not implemented for scatter op with >1 input.");
@@ -26,20 +33,18 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,

auto checkCommonScatterOp = mlir::stablehlo::CheckCommonScatterOp(scatterOp);

if (!checkCommonScatterOp.isSetindexScatter) {
if (!checkCommonScatterOp.isSetindexScatter &&
!checkCommonScatterOp.isConstantSetindexScatter) {
return absl::InvalidArgumentError("ScatterOp is not a setindex op.");
}

if (checkCommonScatterOp.isConstantSetindexScatter) {
constSetIndexValue = checkCommonScatterOp.constant;
}

auto input = scatterOp.getInputs()[0];
if (onlyConstantZerosAllowed) {
if (matchPattern(input, m_AnyZeroFloat()) ||
matchPattern(input, m_Zero())) {
return absl::OkStatus();
}
} else {
if (matchPattern(input, m_Constant(constAttr))) {
return absl::OkStatus();
}
if (inputValidator(input)) {
return absl::OkStatus();
}

return absl::InvalidArgumentError(
@@ -49,7 +54,11 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
// TODO: detect batched diagonal tensors
absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
mlir::Value *outUpdates) {
auto status = detectConstantSetindexScatterOp(scatterOp, true, true, nullptr);
auto status =
detectConstantSetindexScatterOp(scatterOp, true, [](mlir::Value input) {
return matchPattern(input, m_AnyZeroFloat()) ||
matchPattern(input, m_Zero());
});
if (!status.ok())
return status;

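The detection helper now takes a caller-supplied predicate over the scatter input instead of the removed onlyConstantZerosAllowed / DenseElementsAttr* pair, and the new overload threads out the splat constant produced by a constant-setindex update region. A minimal caller sketch, assuming a scatterOp is already in scope and using only matchers that appear elsewhere in this diff:

  mlir::SplatElementsAttr constSetIndexValue = nullptr;
  absl::Status status = mlir::enzyme::detectConstantSetindexScatterOp(
      scatterOp, /*allowedMultipleUses=*/false,
      // Accept any constant scatter input.
      [](mlir::Value input) { return matchPattern(input, m_Constant()); },
      constSetIndexValue);
  if (status.ok() && constSetIndexValue) {
    // The update region returns a single splat value, captured above.
  }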
10 changes: 8 additions & 2 deletions src/enzyme_ad/jax/Passes/StructuredTensors.h
@@ -7,15 +7,21 @@

#include "llvm/ADT/SetVector.h"

#include <functional>
#include <optional>

namespace mlir {
namespace enzyme {

using InputValidatorFn = std::function<bool(mlir::Value)>;

absl::Status detectConstantSetindexScatterOp(
stablehlo::ScatterOp scatterOp, bool allowedMultipleUses,
InputValidatorFn inputValidator, SplatElementsAttr &constSetIndexValue);

absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
bool allowedMultipleUses,
bool onlyConstantZerosAllowed,
DenseElementsAttr *constAttr);
InputValidatorFn inputValidator);

absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
mlir::Value *outUpdates);
36 changes: 27 additions & 9 deletions src/enzyme_ad/jax/Utils.cpp
@@ -1422,26 +1422,44 @@ getGatherDims(mlir::MLIRContext *ctx,
scatterDimNumbers.getIndexVectorDim());
}

bool isSetindexBlock(mlir::Block *block) {
if (block->getNumArguments() != 2)
bool isSetindexBlockHelper(
mlir::Block *block,
std::function<bool(stablehlo::ReturnOp retOp, Value updateValue)> fn) {
if (block->getNumArguments() != 2) {
return false;

auto updateValue = block->getArgument(1);
}

// The block should have exactly one operation (the return)
if (block->getOperations().size() != 1)
if (block->getOperations().size() != 1) {
return false;
}

auto &returnOp = block->front();
auto stablehloReturnOp = dyn_cast<stablehlo::ReturnOp>(returnOp);
if (!stablehloReturnOp)
if (!stablehloReturnOp) {
return false;
}

if (stablehloReturnOp.getNumOperands() != 1)
if (stablehloReturnOp.getNumOperands() != 1) {
return false;
}

return fn(stablehloReturnOp, block->getArgument(1));
}

// The returned value should be the update value (second argument)
return stablehloReturnOp.getOperand(0) == updateValue;
bool isSetindexBlock(mlir::Block *block) {
return isSetindexBlockHelper(
block, [](stablehlo::ReturnOp retOp, Value updateValue) {
return retOp.getOperand(0) == updateValue;
});
}

bool isConstantSetindexBlock(mlir::Block *block,
mlir::SplatElementsAttr &constant) {
return isSetindexBlockHelper(
block, [&constant](stablehlo::ReturnOp retOp, Value updateValue) {
return matchPattern(retOp.getOperand(0), m_Constant(&constant));
});
}

SmallVector<int64_t> computeGatherSliceSizes(stablehlo::ScatterOp &scatterOp) {
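Both block classifiers share isSetindexBlockHelper: the region must have exactly two arguments and a single stablehlo.return with one operand. The plain setindex form returns the update argument; the constant form returns a splat constant, which necessarily lives outside the region since the region may contain only the return. Roughly, with hypothetical names:

// isSetindexBlock matches:
^bb0(%old: tensor<f32>, %upd: tensor<f32>):
  stablehlo.return %upd : tensor<f32>

// isConstantSetindexBlock matches (with, say,
// %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
// defined in the enclosing function):
^bb0(%old: tensor<f32>, %upd: tensor<f32>):
  stablehlo.return %cst : tensor<f32>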
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/Utils.h
@@ -942,6 +942,8 @@ getGatherDims(mlir::MLIRContext *ctx,
stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers);

bool isSetindexBlock(mlir::Block *block);
bool isConstantSetindexBlock(mlir::Block *block,
mlir::SplatElementsAttr &constant);

// rhs is only considered if commutative is false
template <typename T, bool commutative, bool rhs>
@@ -1067,6 +1069,8 @@ struct CheckCommonReduceOp {
struct CheckCommonScatterOp {
public:
bool isSetindexScatter;
bool isConstantSetindexScatter;

bool isAddScatter;
bool isMinScatter;
bool isMaxScatter;
@@ -1087,6 +1091,7 @@

if (!updateComputation.hasOneBlock()) {
isSetindexScatter = false;
isConstantSetindexScatter = false;
isAddScatter = false;
isMinScatter = false;
isMaxScatter = false;
@@ -1105,6 +1110,7 @@

auto &block = updateComputation.front();
isSetindexScatter = isSetindexBlock(&block);
isConstantSetindexScatter = isConstantSetindexBlock(&block, constant);
isAddScatter = isOnlyOpBlock<stablehlo::AddOp, true, false>(&block);
isMulScatter = isOnlyOpBlock<stablehlo::MulOp, true, false>(&block);
isMinScatter = isOnlyOpBlock<stablehlo::MinOp, true, false>(&block);
4 changes: 2 additions & 2 deletions test/lit_tests/mulscatter.mlir
@@ -11,7 +11,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1
%4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64>
%5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64>
%6 = stablehlo.subtract %5, %c : tensor<24x2xi64>
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
stablehlo.return %arg4 : tensor<f32>
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
@@ -27,7 +27,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1
// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[arg2_T]], %[[SCATTER_INDICES:.*]]) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf32>, tensor<24x2xi64>) -> tensor<24xf32>
// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[GATHER]], %[[CST]] : tensor<24xf32>
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK: return %[[RESULT]] : tensor<1024x1024xf32>
// CHECK: }
35 changes: 35 additions & 0 deletions test/lit_tests/mulscatterones.mlir
@@ -0,0 +1,35 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
%c = stablehlo.constant dense<1> : tensor<24x2xi64>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf32>
%0 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
%1 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 : (tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<24xi64>
%2 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<6xi64>) -> tensor<6x4xi64>
%3 = stablehlo.reshape %1 : (tensor<24xi64>) -> tensor<24x1xi64>
%4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64>
%5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64>
%6 = stablehlo.subtract %5, %c : tensor<24x2xi64>
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
stablehlo.return %arg4 : tensor<f32>
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
%8 = stablehlo.multiply %7, %0 : tensor<1024x1024xf32>
%9 = stablehlo.transpose %8, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %9 : tensor<1024x1024xf32>
}

// CHECK: func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK: %[[S_CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: %[[CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
// CHECK: %[[CST_0:.*]] = stablehlo.constant dense<1> : tensor<24x2xi64>
// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[arg2_T]], %[[SCATTER_INDICES:.*]], %[[CST]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ARG3]], %[[S_CST]] : tensor<f32>
// CHECK: stablehlo.return %[[MUL]] : tensor<f32>
// CHECK: })
// CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK: return %[[RESULT]] : tensor<1024x1024xf32>
// CHECK: }
6 changes: 3 additions & 3 deletions test/lit_tests/unaryscatter.mlir
@@ -88,7 +88,7 @@ func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tens
%5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
%6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
%7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
%8 = "stablehlo.scatter"(%c_3, %7, %c_1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
%8 = "stablehlo.scatter"(%c_3, %7, %c_1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
stablehlo.return %arg3 : tensor<i1>
}) : (tensor<4x5xi1>, tensor<5x2xi64>, tensor<5xi1>) -> tensor<4x5xi1>
@@ -111,8 +111,8 @@ func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tens
// CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
// CHECK-NEXT: %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
// CHECK-NEXT: %7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
// CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32>
// CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
// CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32>
// CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK-NEXT: stablehlo.return %arg3 : tensor<f32>
// CHECK-NEXT: }) : (tensor<4x5xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<4x5xf32>