|
5 | 5 | #include <ATen/native/LinearAlgebraUtils.h>
|
6 | 6 | #include <ATen/native/TensorFactories.h>
|
7 | 7 | #include <ATen/native/mps/OperationUtils.h>
|
| 8 | +#include <fmt/format.h> |
8 | 9 |
|
9 | 10 | #ifndef AT_PER_OPERATOR_HEADERS
|
10 | 11 | #include <ATen/Functions.h>
|
|
26 | 27 | #include <ATen/native/mps/TriangularOps_metallib.h>
|
27 | 28 | #endif
|
28 | 29 |
|
29 |
| -TORCH_IMPL_FUNC(triu_mps_out) |
30 |
| -(const Tensor& self, int64_t k, const Tensor& output) { |
31 |
| - using namespace mps; |
32 |
| - using CachedGraph = MPSUnaryCachedGraph; |
33 |
| - |
34 |
| - if (self.numel() == 0) { |
35 |
| - return; |
36 |
| - } |
37 |
| - auto stream = getCurrentMPSStream(); |
38 |
| - |
39 |
| - @autoreleasepool { |
40 |
| - std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); |
41 |
| - auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |
42 |
| - MPSGraphTensor* outputTensor = nil; |
43 |
| - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |
44 |
| - |
45 |
| - auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; |
46 |
| - |
47 |
| - if (k > 0) { |
48 |
| - auto diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; |
49 |
| - auto onesTensor = [mpsGraph constantWithScalar:1 shape:inputTensor.shape dataType:MPSDataTypeInt32]; |
50 |
| - auto maskTensor = [mpsGraph bandPartWithTensor:onesTensor |
51 |
| - numLowerTensor:minusOneTensor |
52 |
| - numUpperTensor:diagMinusOneTensor |
53 |
| - name:nil]; |
54 |
| - outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor |
55 |
| - truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType] |
56 |
| - falsePredicateTensor:inputTensor |
57 |
| - name:nil]; |
58 |
| - } else { |
59 |
| - auto minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; |
60 |
| - outputTensor = [mpsGraph bandPartWithTensor:inputTensor |
61 |
| - numLowerTensor:minusDiagTensor |
62 |
| - numUpperTensor:minusOneTensor |
63 |
| - name:nil]; |
64 |
| - } |
65 |
| - |
66 |
| - newCachedGraph->inputTensor_ = inputTensor; |
67 |
| - newCachedGraph->outputTensor_ = outputTensor; |
68 |
| - }); |
69 |
| - |
70 |
| - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); |
71 |
| - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); |
72 |
| - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); |
| 30 | +template <typename T> |
| 31 | +static std::vector<T> reverse_array(const IntArrayRef& arr) { |
| 32 | + std::vector<T> rc(arr.size()); |
| 33 | + for (const auto& i : c10::irange(arr.size())) { |
| 34 | + rc[i] = arr[arr.size() - 1 - i]; |
73 | 35 | }
|
| 36 | + return rc; |
74 | 37 | }
|
75 | 38 |
|
76 |
| -TORCH_IMPL_FUNC(tril_mps_out) |
77 |
| -(const Tensor& self, int64_t k, const Tensor& output) { |
// Shared implementation behind the triu/tril MPS kernels.
//
// Looks up the Metal pipeline named "<name>[_inplace]_int_<metal scalar type>" in `lib`,
// binds the tensor(s) together with their reversed (last-dimension-first) sizes and strides,
// and launches a grid of self.numel() threads: x spans the last dimension, y the
// second-to-last dimension, and z the remaining (batch) elements.
//
//   self: input tensor; the function is a no-op when it has no elements.
//   k:    diagonal offset. Any 64-bit value is accepted; see the clamping note below.
//   out:  destination tensor. When it is the same tensor as `self`, the "_inplace"
//         kernel variant is used and only `self` is bound.
//   name: kernel family prefix ("triu" or "tril" at the call sites in this file).
static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) {
  using namespace mps;
  if (self.numel() == 0) {
    return;
  }
  // NOTE(review): the code below indexes the two innermost dimensions, so it assumes
  // self.dim() >= 2 — presumably guaranteed by the operator's shape check; confirm.
  const int64_t num_rows = self.size(-2);
  const int64_t num_cols = self.size(-1);

  auto sizes = reverse_array<uint32_t>(self.sizes());
  auto inp_strides = reverse_array<int32_t>(self.strides());
  auto out_strides = reverse_array<int32_t>(out.strides());

  // The kernel receives the diagonal as a 32-bit int. Narrowing `k` directly would wrap
  // for large offsets (e.g. k == 2^32 would behave like k == 0). Every element satisfies
  // -num_rows < col - row < num_cols, so clamping k to [-num_rows, num_cols] selects
  // exactly the same elements for both triu (col - row >= k) and tril (col - row <= k).
  const int64_t clamped_k = std::max<int64_t>(-num_rows, std::min<int64_t>(k, num_cols));
  std::array<int, 2> k_ndim = {static_cast<int>(clamped_k), static_cast<int>(self.ndimension())};

  // Number of matrices in the batch. Computed in 64-bit: the product of the two
  // innermost sizes can exceed the range of uint32_t.
  const auto num_batches = static_cast<NSUInteger>(self.numel() / (num_rows * num_cols));

  const bool inplace = self.is_same(out);
  const auto kernel_name =
      fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self));
  auto triuPSO = lib.getPipelineStateForFunc(kernel_name);
  uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup];
  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto computeEncoder = stream->commandEncoder();
      [computeEncoder setComputePipelineState:triuPSO];
      if (inplace) {
        mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim);
      } else {
        mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim);
      }
      // Thread groups are one-dimensional along the last tensor dimension, capped at the
      // pipeline's maximum thread count.
      [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], num_batches)
                threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)];
    }
  });
}
114 | 68 |
|
115 |
| - newCachedGraph->inputTensor_ = inputTensor; |
116 |
| - newCachedGraph->outputTensor_ = outputTensor; |
117 |
| - }); |
118 |
| - |
119 |
| - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); |
120 |
| - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); |
// MPS implementation of triu: forwards to the shared triangular helper,
// selecting the "triu" kernel family.
TORCH_IMPL_FUNC(triu_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
  const std::string kernel_family = "triu";
  triu_tril_impl(self, k, output, kernel_family);
}
121 | 73 |
|
122 |
| - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); |
123 |
| - } |
// MPS implementation of tril: forwards to the shared triangular helper,
// selecting the "tril" kernel family.
TORCH_IMPL_FUNC(tril_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
  const std::string kernel_family = "tril";
  triu_tril_impl(self, k, output, kernel_family);
}
|
125 | 78 |
|
126 | 79 | Tensor tril_indices_mps(int64_t row,
|
|
0 commit comments