|
5 | 5 | #include <ATen/native/LinearAlgebraUtils.h> |
6 | 6 | #include <ATen/native/TensorFactories.h> |
7 | 7 | #include <ATen/native/mps/OperationUtils.h> |
| 8 | +#include <fmt/format.h> |
8 | 9 |
|
9 | 10 | #ifndef AT_PER_OPERATOR_HEADERS |
10 | 11 | #include <ATen/Functions.h> |
|
26 | 27 | #include <ATen/native/mps/TriangularOps_metallib.h> |
27 | 28 | #endif |
28 | 29 |
|
29 | | -TORCH_IMPL_FUNC(triu_mps_out) |
30 | | -(const Tensor& self, int64_t k, const Tensor& output) { |
31 | | - using namespace mps; |
32 | | - using CachedGraph = MPSUnaryCachedGraph; |
33 | | - |
34 | | - if (self.numel() == 0) { |
35 | | - return; |
36 | | - } |
37 | | - auto stream = getCurrentMPSStream(); |
38 | | - |
39 | | - @autoreleasepool { |
40 | | - std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); |
41 | | - auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |
42 | | - MPSGraphTensor* outputTensor = nil; |
43 | | - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |
44 | | - |
45 | | - auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; |
46 | | - |
47 | | - if (k > 0) { |
48 | | - auto diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; |
49 | | - auto onesTensor = [mpsGraph constantWithScalar:1 shape:inputTensor.shape dataType:MPSDataTypeInt32]; |
50 | | - auto maskTensor = [mpsGraph bandPartWithTensor:onesTensor |
51 | | - numLowerTensor:minusOneTensor |
52 | | - numUpperTensor:diagMinusOneTensor |
53 | | - name:nil]; |
54 | | - outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor |
55 | | - truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType] |
56 | | - falsePredicateTensor:inputTensor |
57 | | - name:nil]; |
58 | | - } else { |
59 | | - auto minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; |
60 | | - outputTensor = [mpsGraph bandPartWithTensor:inputTensor |
61 | | - numLowerTensor:minusDiagTensor |
62 | | - numUpperTensor:minusOneTensor |
63 | | - name:nil]; |
64 | | - } |
65 | | - |
66 | | - newCachedGraph->inputTensor_ = inputTensor; |
67 | | - newCachedGraph->outputTensor_ = outputTensor; |
68 | | - }); |
69 | | - |
70 | | - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); |
71 | | - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); |
72 | | - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); |
| 30 | +template <typename T> |
| 31 | +static std::vector<T> reverse_array(const IntArrayRef& arr) { |
| 32 | + std::vector<T> rc(arr.size()); |
| 33 | + for (const auto& i : c10::irange(arr.size())) { |
| 34 | + rc[i] = arr[arr.size() - 1 - i]; |
73 | 35 | } |
| 36 | + return rc; |
74 | 37 | } |
75 | 38 |
|
76 | | -TORCH_IMPL_FUNC(tril_mps_out) |
77 | | -(const Tensor& self, int64_t k, const Tensor& output) { |
| 39 | +static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) { |
78 | 40 | using namespace mps; |
79 | | - using CachedGraph = MPSUnaryCachedGraph; |
80 | | - |
81 | 41 | if (self.numel() == 0) { |
82 | 42 | return; |
83 | 43 | } |
84 | | - |
| 44 | + auto sizes = reverse_array<uint32_t>(self.sizes()); |
| 45 | + auto inp_strides = reverse_array<int32_t>(self.strides()); |
| 46 | + auto out_strides = reverse_array<int32_t>(out.strides()); |
| 47 | + std::array<int, 2> k_ndim = {int(k), int(self.ndimension())}; |
| 48 | + const bool inplace = self.is_same(out); |
| 49 | + const auto kernel_name = |
| 50 | + fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self)); |
| 51 | + auto triuPSO = lib.getPipelineStateForFunc(kernel_name); |
| 52 | + uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup]; |
85 | 53 | auto stream = getCurrentMPSStream(); |
86 | | - |
87 | | - @autoreleasepool { |
88 | | - std::string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); |
89 | | - auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |
90 | | - MPSGraphTensor* outputTensor = nil; |
91 | | - |
92 | | - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |
93 | | - auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; |
94 | | - |
95 | | - if (k >= 0) { |
96 | | - auto diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32]; |
97 | | - outputTensor = [mpsGraph bandPartWithTensor:inputTensor |
98 | | - numLowerTensor:minusOneTensor |
99 | | - numUpperTensor:diagTensor |
100 | | - name:nil]; |
| 54 | + dispatch_sync_with_rethrow(stream->queue(), ^() { |
| 55 | + @autoreleasepool { |
| 56 | + auto computeEncoder = stream->commandEncoder(); |
| 57 | + [computeEncoder setComputePipelineState:triuPSO]; |
| 58 | + if (inplace) { |
| 59 | + mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim); |
101 | 60 | } else { |
102 | | - auto negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32]; |
103 | | - auto complementTensor = [mpsGraph bandPartWithTensor:inputTensor |
104 | | - numLowerTensor:negDiagMinusOneTensor |
105 | | - numUpperTensor:minusOneTensor |
106 | | - name:nil]; |
107 | | - auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:getMPSDataType(self)]; |
108 | | - auto mask = [mpsGraph equalWithPrimaryTensor:complementTensor secondaryTensor:zeroTensor name:nil]; |
109 | | - outputTensor = [mpsGraph selectWithPredicateTensor:mask |
110 | | - truePredicateTensor:inputTensor |
111 | | - falsePredicateTensor:zeroTensor |
112 | | - name:nil]; |
| 61 | + mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim); |
113 | 62 | } |
| 63 | + [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], self.numel() / (sizes[0] * sizes[1])) |
| 64 | + threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)]; |
| 65 | + } |
| 66 | + }); |
| 67 | +} |
114 | 68 |
|
115 | | - newCachedGraph->inputTensor_ = inputTensor; |
116 | | - newCachedGraph->outputTensor_ = outputTensor; |
117 | | - }); |
118 | | - |
119 | | - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); |
120 | | - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); |
| 69 | +TORCH_IMPL_FUNC(triu_mps_out) |
| 70 | +(const Tensor& self, int64_t k, const Tensor& output) { |
| 71 | + triu_tril_impl(self, k, output, "triu"); |
| 72 | +} |
121 | 73 |
|
122 | | - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); |
123 | | - } |
| 74 | +TORCH_IMPL_FUNC(tril_mps_out) |
| 75 | +(const Tensor& self, int64_t k, const Tensor& output) { |
| 76 | + triu_tril_impl(self, k, output, "tril"); |
124 | 77 | } |
125 | 78 |
|
126 | 79 | Tensor tril_indices_mps(int64_t row, |
|
0 commit comments