
Commit 5534685

pytorchbot and malfet authored
[MPS] Reimplement tri[ul] as Metal shaders (pytorch#158867)
[MPS] Reimplement `tri[ul]` as Metal shaders (pytorch#157179)

And add in-place flavor, as it is currently broken for non-contig tensors.

Pull Request resolved: pytorch#157179
Approved by: https://github.com/dcci

(cherry picked from commit a1e4f1f)
Co-authored-by: Nikita Shulga <[email protected]>
1 parent d19e08d commit 5534685
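
The commit message notes that the in-place variant was previously broken for non-contiguous tensors. A minimal usage sketch of the behavior this change targets (the shape, diagonal, and transposed view below are illustrative assumptions, not taken from the PR; it requires an MPS-capable machine running a build that contains this commit):

    import torch

    # A transposed view is non-contiguous, which is the case the commit
    # message calls out as broken for the in-place variant.
    x = torch.arange(16.0, device="mps").reshape(4, 4).t()
    assert not x.is_contiguous()

    x.triu_(diagonal=1)  # in-place upper-triangular masking on the view

    # Compare against the CPU reference implementation.
    ref = torch.arange(16.0).reshape(4, 4).t().triu_(diagonal=1)
    torch.testing.assert_close(x.cpu(), ref)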


4 files changed, +157 -85 lines changed


aten/src/ATen/native/mps/kernels/TriangularOps.metal

Lines changed: 114 additions & 0 deletions
@@ -1,5 +1,119 @@
 #include <metal_stdlib>
+
 using namespace metal;
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+template <bool upper>
+inline bool triul_mask(int row, int col, int k);
+template <>
+inline bool triul_mask<true>(int row, int col, int k) {
+  return col - row >= k;
+}
+template <>
+inline bool triul_mask<false>(int row, int col, int k) {
+  return col - row <= k;
+}
+
+template <typename IndexType>
+inline IndexType compute_offs(
+    constant IndexType* strides,
+    constant uint* sizes,
+    uint3 pos,
+    int ndim) {
+  auto offs = pos.x * strides[0] + pos.y * strides[1];
+  if (ndim < 4) {
+    return ndim == 3 ? offs + pos.z * strides[2] : offs;
+  }
+  auto idx = pos.z;
+  for (int i = 2; i < ndim; ++i) {
+    offs += strides[i] * (idx % sizes[i]);
+    idx /= sizes[i];
+  }
+  return offs;
+}
+
+template <typename T, typename IndexType, bool upper>
+kernel void triul_inplace(
+    device T* self,
+    constant IndexType* strides,
+    constant uint* sizes,
+    constant int2& k_ndim,
+    uint3 pos [[thread_position_in_grid]]) {
+  if (triul_mask<upper>(pos.y, pos.x, k_ndim.x)) {
+    return;
+  }
+  auto offs = compute_offs(strides, sizes, pos, k_ndim.y);
+  self[offs] = 0;
+}
+
+template <typename T, typename IndexType, bool upper>
+kernel void triul(
+    device T* out,
+    device T* inp,
+    constant IndexType* out_strides,
+    constant IndexType* inp_strides,
+    constant uint* sizes,
+    constant int2& k_ndim,
+    uint3 pos [[thread_position_in_grid]]) {
+  auto out_offs = compute_offs(out_strides, sizes, pos, k_ndim.y);
+  if (!triul_mask<upper>(pos.y, pos.x, k_ndim.x)) {
+    out[out_offs] = 0;
+    return;
+  }
+  auto inp_offs = compute_offs(inp_strides, sizes, pos, k_ndim.y);
+  out[out_offs] = inp[inp_offs];
+}
+
+#define INSTANTIATE_TRIUL_KERNELS(DTYPE, IDX_TYPE) \
+  template [[host_name("triu_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void \
+  triul_inplace<DTYPE, IDX_TYPE, true>( \
+      device DTYPE * self, \
+      constant IDX_TYPE * strides, \
+      constant uint * sizes, \
+      constant int2 & k_ndim, \
+      uint3 pos [[thread_position_in_grid]]); \
+  template [[host_name("tril_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void \
+  triul_inplace<DTYPE, IDX_TYPE, false>( \
+      device DTYPE * self, \
+      constant IDX_TYPE * strides, \
+      constant uint * sizes, \
+      constant int2 & k_ndim, \
+      uint3 pos [[thread_position_in_grid]]); \
+  template [[host_name("triu_" #IDX_TYPE "_" #DTYPE)]] kernel void \
+  triul<DTYPE, IDX_TYPE, true>( \
+      device DTYPE * out, \
+      device DTYPE * inp, \
+      constant IDX_TYPE * out_strides, \
+      constant IDX_TYPE * inp_strides, \
+      constant uint * sizes, \
+      constant int2 & k_ndim, \
+      uint3 pos [[thread_position_in_grid]]); \
+  template [[host_name("tril_" #IDX_TYPE "_" #DTYPE)]] kernel void \
+  triul<DTYPE, IDX_TYPE, false>( \
+      device DTYPE * out, \
+      device DTYPE * inp, \
+      constant IDX_TYPE * out_strides, \
+      constant IDX_TYPE * inp_strides, \
+      constant uint * sizes, \
+      constant int2 & k_ndim, \
+      uint3 pos [[thread_position_in_grid]])
+
+INSTANTIATE_TRIUL_KERNELS(float, int);
+INSTANTIATE_TRIUL_KERNELS(half, int);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_TRIUL_KERNELS(bfloat, int);
+#endif
+
+INSTANTIATE_TRIUL_KERNELS(float2, int);
+INSTANTIATE_TRIUL_KERNELS(half2, int);
+
+INSTANTIATE_TRIUL_KERNELS(long, int);
+INSTANTIATE_TRIUL_KERNELS(int, int);
+INSTANTIATE_TRIUL_KERNELS(short, int);
+INSTANTIATE_TRIUL_KERNELS(char, int);
+INSTANTIATE_TRIUL_KERNELS(uchar, int);
+INSTANTIATE_TRIUL_KERNELS(bool, int);
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 // To find the max integer that does not exceed the root of an int64_t variable,
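
For orientation, here is a Python mirror of the compute_offs indexing scheme used by the kernels above (an illustrative sketch, not part of the commit): the host passes sizes and strides in reversed, innermost-first order, pos.x and pos.y address the last two tensor dimensions, and any remaining leading dimensions are folded into pos.z and unpacked by repeated modulo and divide:

    def compute_offs(strides, sizes, pos, ndim):
        # strides/sizes are reversed: index 0 is the innermost (last) dimension.
        x, y, z = pos
        offs = x * strides[0] + y * strides[1]
        if ndim < 4:
            return offs + z * strides[2] if ndim == 3 else offs
        idx = z
        for i in range(2, ndim):
            offs += strides[i] * (idx % sizes[i])
            idx //= sizes[i]
        return offs

    # Example: a contiguous 2x3x4x5 tensor has strides (60, 20, 5, 1); reversed:
    sizes = [5, 4, 3, 2]
    strides = [1, 5, 20, 60]
    # Thread position (x=2, y=1, z=4) lands on element [1, 1, 1, 2], offset 87.
    print(compute_offs(strides, sizes, (2, 1, 4), ndim=4))  # 87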

aten/src/ATen/native/mps/operations/TriangularOps.mm

Lines changed: 36 additions & 83 deletions
@@ -5,6 +5,7 @@
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/TensorFactories.h>
 #include <ATen/native/mps/OperationUtils.h>
+#include <fmt/format.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -26,101 +27,53 @@
 #include <ATen/native/mps/TriangularOps_metallib.h>
 #endif
 
-TORCH_IMPL_FUNC(triu_mps_out)
-(const Tensor& self, int64_t k, const Tensor& output) {
-  using namespace mps;
-  using CachedGraph = MPSUnaryCachedGraph;
-
-  if (self.numel() == 0) {
-    return;
-  }
-  auto stream = getCurrentMPSStream();
-
-  @autoreleasepool {
-    std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      MPSGraphTensor* outputTensor = nil;
-      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-
-      auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
-
-      if (k > 0) {
-        auto diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32];
-        auto onesTensor = [mpsGraph constantWithScalar:1 shape:inputTensor.shape dataType:MPSDataTypeInt32];
-        auto maskTensor = [mpsGraph bandPartWithTensor:onesTensor
-                                        numLowerTensor:minusOneTensor
-                                        numUpperTensor:diagMinusOneTensor
-                                                  name:nil];
-        outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
-                                       truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType]
-                                      falsePredicateTensor:inputTensor
-                                                      name:nil];
-      } else {
-        auto minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32];
-        outputTensor = [mpsGraph bandPartWithTensor:inputTensor
-                                     numLowerTensor:minusDiagTensor
-                                     numUpperTensor:minusOneTensor
-                                               name:nil];
-      }
-
-      newCachedGraph->inputTensor_ = inputTensor;
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-    runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
+template <typename T>
+static std::vector<T> reverse_array(const IntArrayRef& arr) {
+  std::vector<T> rc(arr.size());
+  for (const auto& i : c10::irange(arr.size())) {
+    rc[i] = arr[arr.size() - 1 - i];
   }
+  return rc;
 }
 
-TORCH_IMPL_FUNC(tril_mps_out)
-(const Tensor& self, int64_t k, const Tensor& output) {
+static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) {
   using namespace mps;
-  using CachedGraph = MPSUnaryCachedGraph;
-
   if (self.numel() == 0) {
     return;
   }
-
+  auto sizes = reverse_array<uint32_t>(self.sizes());
+  auto inp_strides = reverse_array<int32_t>(self.strides());
+  auto out_strides = reverse_array<int32_t>(out.strides());
+  std::array<int, 2> k_ndim = {int(k), int(self.ndimension())};
+  const bool inplace = self.is_same(out);
+  const auto kernel_name =
+      fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self));
+  auto triuPSO = lib.getPipelineStateForFunc(kernel_name);
+  uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup];
   auto stream = getCurrentMPSStream();
-
-  @autoreleasepool {
-    std::string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      MPSGraphTensor* outputTensor = nil;
-
-      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-      auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
-
-      if (k >= 0) {
-        auto diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32];
-        outputTensor = [mpsGraph bandPartWithTensor:inputTensor
-                                     numLowerTensor:minusOneTensor
-                                     numUpperTensor:diagTensor
-                                               name:nil];
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:triuPSO];
+      if (inplace) {
+        mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim);
       } else {
-        auto negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32];
-        auto complementTensor = [mpsGraph bandPartWithTensor:inputTensor
-                                              numLowerTensor:negDiagMinusOneTensor
-                                              numUpperTensor:minusOneTensor
-                                                        name:nil];
-        auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:getMPSDataType(self)];
-        auto mask = [mpsGraph equalWithPrimaryTensor:complementTensor secondaryTensor:zeroTensor name:nil];
-        outputTensor = [mpsGraph selectWithPredicateTensor:mask
-                                       truePredicateTensor:inputTensor
-                                      falsePredicateTensor:zeroTensor
-                                                      name:nil];
+        mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim);
       }
+      [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], self.numel() / (sizes[0] * sizes[1]))
+                threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)];
+    }
+  });
+}
 
-      newCachedGraph->inputTensor_ = inputTensor;
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+TORCH_IMPL_FUNC(triu_mps_out)
+(const Tensor& self, int64_t k, const Tensor& output) {
+  triu_tril_impl(self, k, output, "triu");
+}
 
-    runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder);
-  }
+TORCH_IMPL_FUNC(tril_mps_out)
+(const Tensor& self, int64_t k, const Tensor& output) {
+  triu_tril_impl(self, k, output, "tril");
 }
 
 Tensor tril_indices_mps(int64_t row,
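
A short Python sketch of how the host code above picks a kernel and shapes its dispatch grid (illustrative only; it mirrors the .mm logic rather than exposing any public API): the Metal function name is assembled as "<op>[_inplace]_<index type>_<dtype>", and the grid is (columns, rows, product of the remaining batch dimensions):

    def kernel_name(op, inplace, metal_dtype, index_type="int"):
        # e.g. ("triu", False, "float") -> "triu_int_float"
        #      ("tril", True,  "half")  -> "tril_inplace_int_half"
        return f"{op}{'_inplace' if inplace else ''}_{index_type}_{metal_dtype}"

    def dispatch_grid(shape):
        # Sizes are reversed on the host, so sizes[0] is the column count and
        # sizes[1] the row count; everything else is batched into the z axis.
        numel = 1
        for s in shape:
            numel *= s
        cols, rows = shape[-1], shape[-2]
        return (cols, rows, numel // (cols * rows))

    print(kernel_name("triu", inplace=True, metal_dtype="float"))  # triu_inplace_int_float
    print(dispatch_grid((2, 8, 4, 5)))                             # (5, 4, 16)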

test/test_mps.py

Lines changed: 5 additions & 0 deletions
@@ -7146,6 +7146,11 @@ def helper(shape, diag=0):
         helper((2, 8, 4, 5), diag=-1)
         helper((2, 8, 4, 5), diag=-2)
         helper((2, 8, 4, 5), diag=-3)
+        # Test inplace
+        x_mps = torch.arange(9.0, device='mps').reshape(3, 3).t().triu()
+        x_cpu = torch.arange(9.0, device='cpu').reshape(3, 3).t().triu()
+        self.assertEqual(x_cpu, x_mps)
+        self.assertEqual(x_cpu.stride(), x_mps.stride())
 
     # Test inverse
     def test_inverse(self):

torch/testing/_internal/common_mps.py

Lines changed: 2 additions & 2 deletions
@@ -157,6 +157,8 @@ def mps_ops_modifier(
         "tensor_split",
         "transpose",
         "transpose_copy",
+        "tril",
+        "triu",
         "true_divide",
         "T",
         "unbind",
@@ -283,8 +285,6 @@ def mps_ops_modifier(
         "trace",
         "trapz",
         "trapezoid",
-        "tril",
-        "triu",
         "vstack",
         "where",
         "byte",
