@@ -368,7 +368,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
}
}
tensorDesc.SetDimensionsAndStrides(newSizes, newStrides);
tensorDesc.EnsureDimensionCount(1, TensorAxis::RightAligned);
tensorDesc.EnsureMinimumDimensionCount(1, TensorAxis::RightAligned);
}

// Reproject a tensor to the given axis arrangement.
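Note on the EinSum change above: going by the names, EnsureDimensionCount forces an exact rank, while EnsureMinimumDimensionCount only pads when the rank is below the requested floor, which is the safer call right after SetDimensionsAndStrides. A minimal sketch of that padding rule, assuming right-aligned shapes are padded with leading 1s; this is illustrative, not the actual TensorDesc implementation:

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: pad a right-aligned shape with leading 1s until it has at
// least minCount dimensions; never truncate a shape that is already larger.
std::vector<uint32_t> EnsureMinimumDimensionCount(std::vector<uint32_t> sizes, size_t minCount)
{
    if (sizes.size() < minCount)
    {
        sizes.insert(sizes.begin(), minCount - sizes.size(), 1u); // right-aligned padding
    }
    return sizes; // e.g. {} -> {1}, {5} -> {5}, {2, 3} -> {2, 3} for minCount == 1
}
```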
@@ -22,16 +22,28 @@ class DmlOperatorMatMul : public DmlOperator
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(1);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
std::vector<DimensionType> inputShape0Broadcasted;
std::vector<DimensionType> inputShape1Broadcasted;

OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape, inputShape0Broadcasted, inputShape1Broadcasted);

// Initialize the input descriptions with broadcasting
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Initialize the input descriptions without broadcasting yet, since MatMul has special rules where broadcasting the
// original shape (notably when 1D) to the output shape would mess up because the dimensions are shifted.
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelInfo, 0);
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1);

// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);

// Broadcast the inputs to their broadcasted shapes.
m_inputTensorDescs[0].SetBroadcastedShape(inputShape0Broadcasted, inputShape0, outputShape.size());
m_inputTensorDescs[1].SetBroadcastedShape(inputShape1Broadcasted, inputShape1, outputShape.size());

// DirectML only supports ranks up to 4D for GEMM, and so any leading dimensions must be folded.
m_inputTensorDescs[0].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);
m_inputTensorDescs[1].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);
m_outputTensorDescs[0].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

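The new MatMulShapeMapping overload returns per-input broadcasted shapes in addition to remapping the output. ONNX MatMul follows numpy.matmul semantics: a 1-D A gets a 1 prepended, a 1-D B gets a 1 appended, and only the leading batch dimensions are broadcast while each input keeps its own trailing GEMM dimensions. A sketch of that shape computation under those assumptions (the function name and signature are illustrative, not the OperatorHelper API, and the real helper also remaps the output shape):

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of numpy.matmul-style shape handling: promote 1-D inputs by inserting
// a 1 (prepended for A, appended for B), broadcast the batch dimensions against
// each other, and keep each input's own trailing GEMM dimensions.
void MatMulBroadcastShapes(
    std::vector<uint32_t> a,
    std::vector<uint32_t> b,
    std::vector<uint32_t>& aBroadcast,
    std::vector<uint32_t>& bBroadcast)
{
    if (a.size() == 1) a.insert(a.begin(), 1);  // [K] -> [1, K]
    if (b.size() == 1) b.push_back(1);          // [K] -> [K, 1]

    size_t rank = std::max(a.size(), b.size());
    aBroadcast.assign(rank, 1);
    bBroadcast.assign(rank, 1);

    // Copy each input right-aligned into its broadcasted shape.
    std::copy(a.begin(), a.end(), aBroadcast.end() - a.size());
    std::copy(b.begin(), b.end(), bBroadcast.end() - b.size());

    // Broadcast only the batch dimensions (everything before the last two).
    for (size_t i = 0; i + 2 < rank; ++i)
    {
        uint32_t da = aBroadcast[i], db = bBroadcast[i];
        if (da != db && da != 1 && db != 1) throw std::invalid_argument("incompatible batch dims");
        aBroadcast[i] = bBroadcast[i] = std::max(da, db);
    }
}
```

This is also why the input descriptions above are now created without broadcasting first: right-aligning a raw 1-D shape directly against the output shape would land the K dimension in the wrong slot, as the comment in the diff notes.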
@@ -25,12 +25,20 @@ class DmlOperatorMatMulInteger : public DmlOperator
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(1);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
std::vector<DimensionType> inputShape0Broadcasted;
std::vector<DimensionType> inputShape1Broadcasted;

OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape, inputShape0Broadcasted, inputShape1Broadcasted);

// Initialize the input descriptions with broadcasting
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 1/*OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Initialize the input descriptions without broadcasting
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*OnnxIndex*/);
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 1/*OnnxIndex*/);

// Broadcast the inputs to their broadcasted shapes.
m_inputTensorDescs[IN_A].SetBroadcastedShape(inputShape0Broadcasted, inputShape0, outputShape.size());
m_inputTensorDescs[IN_B].SetBroadcastedShape(inputShape1Broadcasted, inputShape1, outputShape.size());
m_inputTensorDescs[IN_A].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);
m_inputTensorDescs[IN_B].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount();
// Resize the A ZeroPoint to be the same dimension as the input tensor.
@@ -49,6 +57,7 @@ class DmlOperatorMatMulInteger : public DmlOperator

// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
m_outputTensorDescs[0].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
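SetBroadcastedShape takes the broadcasted shape, the original shape, and the output rank. A common way tensor descriptions express such a broadcast without copying data is to widen the sizes and give every expanded dimension a stride of 0, so all indices along that dimension map back to the single stored element. The sketch below shows that standard technique as an assumption about what the helper does, not its actual implementation (ComputeBroadcastStrides is a made-up name):

```cpp
#include <cstdint>
#include <vector>

// Assumed broadcast-by-strides technique: compute contiguous strides for the
// original right-aligned shape, then expand to the broadcast shape, giving
// every broadcast (repeated) dimension a stride of 0.
void ComputeBroadcastStrides(
    const std::vector<uint32_t>& originalShape,   // e.g. {1, K, N}
    const std::vector<uint32_t>& broadcastShape,  // e.g. {B, K, N}
    std::vector<uint32_t>& outSizes,
    std::vector<uint32_t>& outStrides)
{
    // Strides of the original, contiguous, right-aligned tensor.
    std::vector<uint32_t> strides(originalShape.size());
    uint32_t stride = 1;
    for (size_t i = originalShape.size(); i-- > 0; )
    {
        strides[i] = stride;
        stride *= originalShape[i];
    }

    size_t rank = broadcastShape.size();
    size_t offset = rank - originalShape.size();  // leading dims added by broadcasting
    outSizes = broadcastShape;
    outStrides.assign(rank, 0);                   // new leading dims repeat: stride 0

    for (size_t i = 0; i < originalShape.size(); ++i)
    {
        // A size-1 dimension stretched to a larger size also repeats: stride 0.
        outStrides[offset + i] =
            (originalShape[i] == 1 && broadcastShape[offset + i] != 1) ? 0 : strides[i];
    }
}
```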
@@ -42,20 +42,37 @@ class DmlOperatorMatMulIntegerToFloat : public DmlOperator
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
std::vector<DimensionType> inputShape0Broadcasted;
std::vector<DimensionType> inputShape1Broadcasted;

OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape, inputShape0Broadcasted, inputShape1Broadcasted);

// Initialize the input descriptions with broadcasting
m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Initialize the input descriptions without broadcasting
m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA);
m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB);

// Broadcast the inputs to their broadcasted shapes.
m_inputTensorDescs[DmlInputIndex::dmlA].SetBroadcastedShape(inputShape0Broadcasted, inputShape0, outputShape.size());
m_inputTensorDescs[DmlInputIndex::dmlB].SetBroadcastedShape(inputShape1Broadcasted, inputShape1, outputShape.size());
m_inputTensorDescs[DmlInputIndex::dmlA].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);
m_inputTensorDescs[DmlInputIndex::dmlB].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();

// Broadcast Bias tensor to the shape of the output tensor.
if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) {
m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce,
TensorAxis::W, TensorAxis::RightAligned, outputShape);
if (kernelInfo.IsInputValid(OrtInputTensors::ortBias))
{
m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(
kernelInfo,
OrtInputTensors::ortBias,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
outputShape,
dmlDimSize
);
}

uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();
// Resize the A Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput(
@@ -87,6 +104,7 @@ class DmlOperatorMatMulIntegerToFloat : public DmlOperator

// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
m_outputTensorDescs[0].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
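On the repeated SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true) calls: per the comment in DmlOperatorMatMul, DirectML GEMM only supports up to 4D, so higher-rank shapes must be folded. Presumably the fold collapses surplus leading batch dimensions into one, preserving the element count. A sketch under that assumption (FoldLeadingDimensions is illustrative, and it ignores the stride bookkeeping a real fold of a broadcast tensor would need):

```cpp
#include <cstdint>
#include <vector>

// Assumed behavior of the fold-to-4D step: while a right-aligned shape has more
// than maxRank dimensions, multiply the two outermost dimensions together so
// the total element count is preserved.
std::vector<uint32_t> FoldLeadingDimensions(std::vector<uint32_t> shape, size_t maxRank = 4)
{
    while (shape.size() > maxRank)
    {
        shape[1] *= shape[0];      // merge the two outermost (batch) dimensions
        shape.erase(shape.begin());
    }
    return shape;                  // e.g. {2, 3, 4, 5, 6} -> {6, 4, 5, 6}
}
```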
@@ -29,12 +29,20 @@ class DmlOperatorQLinearMatMul : public DmlOperator
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0/*A OnnxIndex*/);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(3/*B OnnxIndex*/);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
std::vector<DimensionType> inputShape0Broadcasted;
std::vector<DimensionType> inputShape1Broadcasted;

OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape, inputShape0Broadcasted, inputShape1Broadcasted);

// Initialize the input descriptions with broadcasting
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Initialize the input descriptions without broadcasting
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/);
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/);

// Broadcast the inputs to their broadcasted shapes.
m_inputTensorDescs[IN_A].SetBroadcastedShape(inputShape0Broadcasted, inputShape0, outputShape.size());
m_inputTensorDescs[IN_B].SetBroadcastedShape(inputShape1Broadcasted, inputShape1, outputShape.size());
m_inputTensorDescs[IN_A].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);
m_inputTensorDescs[IN_B].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount();
// Resize the A Scale to be the same dimension as the input tensor.
@@ -89,6 +97,7 @@ class DmlOperatorQLinearMatMul : public DmlOperator

// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
m_outputTensorDescs[0].SetDimensionCount(4, TensorAxis::RightAligned, /*foldEndDimensions*/ true);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
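The "moved to the H channel" comments for the scale and zero-point tensors presumably mean that a length-M vector is stored as {1, 1, M, 1}, so it sits on the row (H) dimension of the right-aligned 4D view and broadcasts across columns. A sketch of that reshape under that reading (PlaceVectorOnHChannel is a hypothetical helper, not part of the kernel code):

```cpp
#include <cstdint>
#include <vector>

// Assumed interpretation of "the 1D tensor needs to be moved to the H channel":
// a per-row scale/zero-point of length M becomes {1, 1, M, 1}, aligning it with
// the rows of the trailing {H = M, W = N} GEMM plane.
std::vector<uint32_t> PlaceVectorOnHChannel(uint32_t length, uint32_t dimensionCount = 4)
{
    std::vector<uint32_t> sizes(dimensionCount, 1);
    sizes[dimensionCount - 2] = length;  // H is the second-to-last dimension
    return sizes;                        // e.g. length == M -> {1, 1, M, 1}
}
```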