@@ -5,7 +5,7 @@ namespace ops {
 
 namespace mps {
 
-static const char* METAL_VISION = R"VISION_METAL(
+static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
 
 #include <metal_atomic>
 #include <metal_stdlib>
@@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
   return (n + m - 1) / m;
 }
 
-template <typename T>
-inline void atomic_add_float( device T* data_ptr, const T val)
+inline void atomic_add_float(device float* data_ptr, const float val)
 {
-#if __METAL_VERSION__ >= 300
-  // atomic_float is supported in Metal 3 (macOS Ventura) onward.
-  device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
-#else
-  // Custom atomic addition implementation
-  // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
-  // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
-  // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
-
-  // Create an atomic uint pointer for atomic transaction.
-  device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
-  // Create necessary storage.
-  uint fetched_uint, assigning_uint;
-  T fetched_float, assigning_float;
-
-  // Replace the value in atom_var with 0 and return the previous value in atom_var.
-  fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
-  // Read out the previous value as float.
-  fetched_float = *( (thread T*) &fetched_uint );
-
-  // Do addition and represent the addition result in uint for atomic transaction.
-  assigning_float = fetched_float + val;
-  assigning_uint = *((thread uint*) &assigning_float);
-
-  // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
-  while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
-    // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
-    // Try to assign 0 and get the previously assigned addition result.
-    uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
-    T fetched_float_again = *( (thread T*) &fetched_uint_again );
-    // Re-add again
-    fetched_float = *((thread T*) &(fetched_uint));
-    // Previously assigned addition result + addition result from other threads.
-    assigning_float = fetched_float_again + fetched_float;
-    assigning_uint = *( (thread uint*) &assigning_float);
-  }
-#endif
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
+}
+
+
+inline void atomic_add_float(device half* data_ptr, const half val)
+{
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed);
 }
 
 template <typename T, typename integer_t>
@@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
 
-)VISION_METAL";
-
-static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
-  static id<MTLLibrary> visionLibrary = nil;
-  if (visionLibrary) {
-    return visionLibrary;
-  }
-
-  NSError* error = nil;
-  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
-  [options setLanguageVersion:MTLLanguageVersion2_3];
-  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
-                                       options:options
-                                         error:&error];
-  TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
-  return visionLibrary;
-}
-
-static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
-  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
-  id<MTLComputePipelineState> pso = psoCache[kernel];
-  if (pso) {
-    return pso;
-  }
-
-  NSError* error = nil;
-  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
-  id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
-  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
-  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
-  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+)VISION_METAL");
 
-  psoCache[kernel] = pso;
-  return pso;
+static id<MTLComputePipelineState> visionPipelineState(
+    id<MTLDevice> device,
+    const std::string& kernel) {
+  return lib.getPipelineStateForFunc(kernel);
 }
 
 } // namespace mps
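
Note on the new API used above: `at::native::mps::MetalShaderLibrary` is the ATen MPS helper that owns the embedded Metal source and hands back compute pipeline states by kernel name, handling compilation and PSO caching internally, which is why the hand-rolled `compileVisionOpsLibrary` / `psoCache` code can be dropped. Below is a minimal sketch of that pattern; the `fill_ones` kernel, the `demoPipelineState` wrapper, and the `ATen/native/mps/OperationUtils.h` include path are illustrative assumptions, not part of this change.

```objc
// Minimal sketch (not torchvision code): declare a MetalShaderLibrary with
// embedded Metal source, then fetch cached pipeline states by kernel name.
#include <ATen/native/mps/OperationUtils.h> // assumed location of MetalShaderLibrary

static at::native::mps::MetalShaderLibrary demoLib(R"DEMO_METAL(
#include <metal_stdlib>
using namespace metal;

kernel void fill_ones(device float* out [[buffer(0)]],
                      uint index [[thread_position_in_grid]]) {
  out[index] = 1.0;
}
)DEMO_METAL");

static id<MTLComputePipelineState> demoPipelineState(const std::string& kernel) {
  // Source compilation and pipeline-state caching happen inside
  // MetalShaderLibrary, so no manual newLibraryWithSource / unordered_map
  // bookkeeping is needed here.
  return demoLib.getPipelineStateForFunc(kernel);
}
```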