@@ -368,7 +368,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
             }
         }
         tensorDesc.SetDimensionsAndStrides(newSizes, newStrides);
-        tensorDesc.EnsureDimensionCount(1, TensorAxis::RightAligned);
+        tensorDesc.EnsureMinimumDimensionCount(1, TensorAxis::RightAligned);
     }
 
     // Reproject a tensor to the given axis arrangement.
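For context, a minimal standalone sketch of the "minimum" behavior the renamed call relies on (illustration only; the helper name and std::vector usage are assumptions, not the TensorDesc implementation): a right-aligned shape is padded with leading 1's until it has at least the requested number of dimensions, and a shape already at or above that rank is left unchanged.

#include <cstdint>
#include <vector>

// Hypothetical helper, illustration only: pad a right-aligned shape with leading 1's
// so it has at least minimumRank dimensions; the element count is unchanged.
std::vector<uint32_t> EnsureMinimumRank(std::vector<uint32_t> sizes, uint32_t minimumRank)
{
    while (sizes.size() < minimumRank)
    {
        sizes.insert(sizes.begin(), 1u); // a leading 1 adds a dimension without changing the data
    }
    return sizes;
}

// e.g. EnsureMinimumRank({5}, 1) returns {5} unchanged, and EnsureMinimumRank({}, 1) returns {1},
// mirroring the EnsureMinimumDimensionCount(1, TensorAxis::RightAligned) call above.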
@@ -32,6 +32,11 @@ class DmlOperatorMatMul : public DmlOperator
         // Initialize the output description while overriding the shape
         m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
 
+        // DirectML only supports ranks up to 4D for GEMM, and so leading dimensions must be clamped.
+        m_inputTensorDescs[0].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
+        m_inputTensorDescs[1].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
+        m_outputTensorDescs[0].EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);
+
         std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
         std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
 
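To make the effect of the new clamping calls concrete, here is a standalone sketch of a right-aligned rank clamp (hypothetical helper name and vector-based types; not the DML code): the extra leading dimensions are multiplied into the first surviving dimension, mirroring the folding logic added to TensorDesc::SetDimensionCount further below.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper, illustration only: view a shape with at most maximumRank dimensions
// by multiplying the extra leading dimensions into the first kept dimension (right-aligned).
std::vector<uint32_t> ClampRankRightAligned(const std::vector<uint32_t>& sizes, uint32_t maximumRank)
{
    if (maximumRank == 0 || sizes.size() <= maximumRank)
    {
        return sizes; // nothing to fold
    }
    size_t foldedCount = sizes.size() - maximumRank + 1; // leading dimensions that collapse into one
    std::vector<uint32_t> result;
    result.push_back(std::accumulate(sizes.begin(), sizes.begin() + foldedCount, 1u, std::multiplies<uint32_t>()));
    result.insert(result.end(), sizes.begin() + foldedCount, sizes.end());
    return result;
}

// e.g. a 5D MatMul input of {2, 3, 4, 8, 16} would be treated as the 4D shape {6, 4, 8, 16},
// which is what EnsureMaximumDimensionCount(4, TensorAxis::RightAligned) arranges above.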
@@ -290,11 +290,20 @@
 }
 
 // Add additional padding 1's to ensure the count is at least that large.
-void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
+void TensorDesc::EnsureMinimumDimensionCount(uint32_t minimumDimensionCount, TensorAxis alignment)
 {
-    if (m_bufferTensorDesc.DimensionCount < newDimensionCount)
+    if (m_bufferTensorDesc.DimensionCount < minimumDimensionCount)
     {
-        SetDimensionCount(newDimensionCount, alignment);
+        SetDimensionCount(minimumDimensionCount, alignment);
     }
 }
+
+// Ensure the dimension count is less than or equal to the limit.
+void TensorDesc::EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment)
+{
+    if (m_bufferTensorDesc.DimensionCount > maximumDimensionCount)
+    {
+        SetDimensionCount(maximumDimensionCount, alignment);
+    }
+}
 
@@ -313,7 +322,53 @@
     int32_t fillOffset = oldDimensionCount;
     int32_t fillCount = std::max(0, difference);
 
-    // alignment == TensorAxis::LeftAligned is the easy case.
+    // If shrinking the rank, fold dimensions into the first/last dimension.
+    // e.g. Folding 4D dimensions [2,3,4,5] to 3D right-aligned yields [6,4,5].
+    // e.g. Folding 6D dimensions [2,3,4,5,6,7] to 3D left-aligned yields [2,3,840].
+    if (difference < 0 && newDimensionCount > 0)
+    {
+        uint32_t dimensionCountRemoved = -difference;
+        uint32_t dimensionCountFolded = dimensionCountRemoved + 1; // If 2 dimensions are removed, then 3 dimensions are folded into one.
+        uint32_t targetDimensionIndex;
+        uint32_t firstFoldedDimensionIndex;
+
+        // Determine the range to fold and which dimension to fold them into.
+        if (alignment == TensorAxis::RightAligned)
+        {
+            targetDimensionIndex = dimensionCountRemoved; // Fold extra dimensions into the first dimension of the new size.
+            firstFoldedDimensionIndex = 0;
+        }
+        else // alignment == TensorAxis::LeftAligned
+        {
+            targetDimensionIndex = newDimensionCount - 1; // Fold extra dimensions into the last dimension of the new size.
+            firstFoldedDimensionIndex = targetDimensionIndex;
+        }
+        auto sizeFoldBegin = &m_sizes[firstFoldedDimensionIndex];
+        auto sizeFoldEnd = &m_sizes[firstFoldedDimensionIndex + dimensionCountFolded];
+
+        // Ensure no stride broadcasting is lost during the fold, which would silently give incorrect results.
+        ML_CHECK_VALID_ARGUMENT(
+            m_bufferTensorDesc.Strides == nullptr ||
+            !HasBroadcastedDimensions(
+                { sizeFoldBegin, sizeFoldEnd },
+                { &m_strides[firstFoldedDimensionIndex], dimensionCountFolded }
+            )
+        );
+
+        m_sizes[targetDimensionIndex] = std::accumulate(sizeFoldBegin, sizeFoldEnd, 1u, std::multiplies<uint32_t>());
+
+        // Update strides too (right alignment has no extra work).
+        if (alignment == TensorAxis::LeftAligned)
+        {
+            m_strides[targetDimensionIndex] = m_strides[oldDimensionCount - 1]; // Migrate the last stride to its new position.
+        }
+        // Ensure the target stride is at least 1, not 0, in case a dimension of size 1 was folded that had a stride
+        // of 0 (which might happen because a stride of 0 for a dimension of size 1 is ignorable), and other dimensions
+        // were folded into the target too.
+        m_strides[targetDimensionIndex] = std::max(m_strides[targetDimensionIndex], 1u);
+    }
+
+    // Left alignment is the easy case (just truncate the end).
     // Right alignment needs more work, shifting values over.
     if (alignment == TensorAxis::RightAligned)
     {

[GitHub Actions / Optional Lint C++, cpplint] onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp:358: Add #include <functional> for multiplies<> [build/include_what_you_use]

@@ -322,6 +377,8 @@
         memmove(&m_sizes[fillCount], &m_sizes[oldDimensionCount - moveCount], sizeof(m_sizes[0]) * moveCount);
         memmove(&m_strides[fillCount], &m_strides[oldDimensionCount - moveCount], sizeof(m_strides[0]) * moveCount);
     }
+
+    // For any new dimensions, insert leading/trailing 1's for sizes and 0's for strides.
     if (fillCount > 0)
     {
         std::fill(&m_sizes[fillOffset], &m_sizes[fillOffset] + fillCount, 1u);
@@ -375,3 +432,30 @@
     GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount});
     m_bufferTensorDesc.Strides = m_strides;
 }
+
+bool TensorDesc::HasBroadcastedDimensions(
+    gsl::span<const uint32_t> dimensions,
+    gsl::span<const uint32_t> strides
+    ) noexcept
+{
+    assert(dimensions.size() == strides.size());
+    for (uint32_t i = 0; i < dimensions.size(); ++i)
+    {
+        // Note logical dimensions of size 1 (even when stride is 0) are not considered broadcasted.
+        if (strides[i] == 0 && dimensions[i] != 1)
+        {
+            return true;
+        }
+    }
+    return false;
+}
+
+bool TensorDesc::HasBroadcastedDimensions() const noexcept
+{
+    return IsValid()
+        && m_bufferTensorDesc.Strides != nullptr
+        && HasBroadcastedDimensions(
+            { m_sizes, m_bufferTensorDesc.DimensionCount },
+            { m_strides, m_bufferTensorDesc.DimensionCount }
+        );
+}
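As a worked example of why the fold guards on HasBroadcastedDimensions, here is a standalone sketch using C++20 std::span in place of gsl::span (the sizes and strides are illustrative assumptions): a stride of 0 over a dimension larger than 1 repeats data, and that repetition cannot be expressed once the dimension is folded away.

#include <cassert>
#include <cstdint>
#include <span>

// Standalone restatement of the broadcast test above, illustration only.
bool HasBroadcastedDimensionsSketch(std::span<const uint32_t> dimensions, std::span<const uint32_t> strides)
{
    for (size_t i = 0; i < dimensions.size(); ++i)
    {
        if (strides[i] == 0 && dimensions[i] != 1) // stride 0 over size > 1 means the data repeats
        {
            return true;
        }
    }
    return false;
}

int main()
{
    // Sizes {2, 3} with strides {0, 1}: the first dimension repeats a 3-element row.
    // Folding it into a single dimension of 6 would need offsets 0,1,2,0,1,2 from one stride,
    // which a strided view cannot express, so the fold is (correctly) rejected.
    const uint32_t sizes[] = {2, 3};
    const uint32_t strides[] = {0, 1};
    assert(HasBroadcastedDimensionsSketch(sizes, strides));
}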
@@ -41,12 +41,18 @@ namespace Dml
     inline bool IsValid() const noexcept { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
     inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
     void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
-    void EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
+    void EnsureMinimumDimensionCount(uint32_t minimumDimensionCount, TensorAxis alignment);
+    void EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment);
 
     gsl::span<const uint32_t> GetSizes() const noexcept { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
     gsl::span<const uint32_t> GetStrides() const noexcept;
     void SetStrides(gsl::span<const uint32_t> strides);
     void EnsureStridesExist() noexcept;
+    bool HasBroadcastedDimensions() const noexcept;
+    static bool HasBroadcastedDimensions(
+        gsl::span<const uint32_t> dimensions,
+        gsl::span<const uint32_t> strides
+        ) noexcept;
 
     void SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides);
 
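Finally, a hypothetical usage sketch of how the new declarations compose at a call site (assumes the TensorDesc.h header name and an already initialized descriptor; this function is not part of the change):

#include "TensorDesc.h" // assumed header name/path

namespace Dml
{
    // Hypothetical call site, illustration only.
    void CoerceToGemmCompatibleRank(TensorDesc& desc)
    {
        // Pad a 0D view up to 1D, then fold any extra leading batch dimensions down to the 4D limit.
        desc.EnsureMinimumDimensionCount(1, TensorAxis::RightAligned);
        desc.EnsureMaximumDimensionCount(4, TensorAxis::RightAligned);

        // The fold validates internally that no broadcast stride is lost; callers that need to
        // branch on broadcasting can also query the descriptor directly.
        if (desc.HasBroadcastedDimensions())
        {
            // e.g. materialize the broadcast first (illustrative branch).
        }
    }
}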