Skip to content

Commit 0f1a88c

Browse files
fffrogpytorchmergebot
authored andcommitted
Make Context to be Device-agnostic Step by Step (2/N) (pytorch#136526)
---- - add new method(getDefaultGenerator, getNewGenerator) into AcceleratorHooksInterface Pull Request resolved: pytorch#136526 Approved by: https://github.com/ezyang, https://github.com/EikanWang
1 parent cca34be commit 0f1a88c

File tree

20 files changed

+70
-82
lines changed

20 files changed

+70
-82
lines changed

aten/src/ATen/Context.h

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,9 @@ class TORCH_API Context {
4444

4545
if (device_type == at::kCPU) {
4646
return at::detail::getDefaultCPUGenerator();
47-
} else if (device_type == at::kCUDA) {
48-
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
49-
} else if (device_type == at::kMPS) {
50-
return at::detail::getMPSHooks().getDefaultMPSGenerator();
51-
} else if (device_type == at::kXPU) {
52-
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
53-
} else if (device_type == at::kIPU) {
54-
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
55-
} else if (device_type == at::kHPU) {
56-
return at::detail::getHPUHooks().getDefaultHPUGenerator(device.index());
57-
} else if (device_type == at::kPrivateUse1) {
58-
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
59-
device.index());
6047
} else {
61-
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
48+
return getAcceleratorHooksInterface(device_type)
49+
.getDefaultGenerator(device.index());
6250
}
6351
}
6452

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void CUDAHooks::init() const {
102102
#endif
103103
}
104104

105-
const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const {
105+
const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const {
106106
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
107107
}
108108

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
2121
void init() const override;
2222
Device getDeviceFromPtr(void* data) const override;
2323
bool isPinnedPtr(const void* data) const override;
24-
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
24+
const Generator& getDefaultGenerator(
25+
DeviceIndex device_index = -1) const override;
2526
bool hasCUDA() const override;
2627
bool hasMAGMA() const override;
2728
bool hasCuDNN() const override;

aten/src/ATen/detail/AcceleratorHooksInterface.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
#pragma once
22

3+
#include <ATen/core/Generator.h>
4+
5+
#include <c10/core/Allocator.h>
36
#include <c10/core/Device.h>
47
#include <c10/core/Stream.h>
5-
#include <c10/core/Allocator.h>
8+
69
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
10+
711
namespace at {
812

913
// AcceleratorHooksInterface is a shared interface provided by all
@@ -58,7 +62,18 @@ struct TORCH_API AcceleratorHooksInterface {
5862
virtual Device getDeviceFromPtr(void* data) const {
5963
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
6064
}
65+
66+
virtual const Generator& getDefaultGenerator(
67+
[[maybe_unused]] DeviceIndex device_index = -1) const {
68+
TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()");
69+
}
70+
71+
virtual Generator getNewGenerator(
72+
[[maybe_unused]] DeviceIndex device_index = -1) const {
73+
TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()");
74+
}
6175
};
6276

6377
} // namespace at
78+
6479
C10_DIAGNOSTIC_POP()

aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@
66

77
#include <ATen/detail/AcceleratorHooksInterface.h>
88

9-
// Forward-declares at::Generator and at::cuda::NVRTC
9+
// NB: Class must live in `at` due to limitations of Registry.h.
1010
namespace at {
11-
struct Generator;
11+
12+
// Forward-declares at::cuda::NVRTC
1213
namespace cuda {
1314
struct NVRTC;
1415
} // namespace cuda
15-
} // namespace at
16-
17-
// NB: Class must live in `at` due to limitations of Registry.h.
18-
namespace at {
1916

2017
#ifdef _MSC_VER
2118
constexpr const char* CUDA_HELP =
@@ -69,8 +66,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
6966
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
7067
}
7168

72-
virtual const Generator& getDefaultCUDAGenerator(
73-
[[maybe_unused]] DeviceIndex device_index = -1) const {
69+
const Generator& getDefaultGenerator(
70+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
7471
TORCH_CHECK(
7572
false,
7673
"Cannot get default CUDA generator without ATen_cuda library. ",

aten/src/ATen/detail/HIPHooksInterface.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
#pragma once
22

33
#include <c10/core/Allocator.h>
4-
#include <c10/core/GeneratorImpl.h>
54
#include <c10/util/Exception.h>
6-
75
#include <c10/util/Registry.h>
86

97
#include <ATen/detail/AcceleratorHooksInterface.h>
108

119
#include <memory>
1210

13-
namespace at {
14-
class Context;
15-
}
16-
1711
// NB: Class must live in `at` due to limitations of Registry.h.
1812
namespace at {
1913

@@ -30,8 +24,9 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
3024
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
3125
}
3226

33-
virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
34-
AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
27+
const Generator& getDefaultGenerator(
28+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
29+
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
3530
}
3631

3732
virtual bool hasHIP() const {
@@ -50,10 +45,6 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
5045
TORCH_CHECK(false, "Pinned memory requires HIP.");
5146
}
5247

53-
virtual void registerHIPTypes(Context*) const {
54-
AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
55-
}
56-
5748
virtual int getNumGPUs() const {
5849
return 0;
5950
}

aten/src/ATen/detail/IPUHooksInterface.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#include <ATen/core/Generator.h>
43
#include <ATen/detail/AcceleratorHooksInterface.h>
54

65
#include <c10/core/Allocator.h>
@@ -9,7 +8,7 @@
98

109
namespace at {
1110

12-
struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
11+
struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface {
1312
~IPUHooksInterface() override = default;
1413

1514
void init() const override {
@@ -21,16 +20,14 @@ struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
2120
return false;
2221
}
2322

24-
virtual const Generator& getDefaultIPUGenerator(
25-
DeviceIndex device_index [[maybe_unused]] = -1) const {
26-
AT_ERROR(
27-
"Cannot get the default IPU generator: the IPU backend is not "
28-
"available.");
23+
const Generator& getDefaultGenerator(
24+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
25+
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
2926
}
3027

31-
virtual Generator newIPUGenerator(DeviceIndex device_index [[maybe_unused]] = -1) const {
32-
AT_ERROR(
33-
"Cannot create a new IPU generator: the IPU backend is not available.");
28+
Generator getNewGenerator(
29+
DeviceIndex device_index [[maybe_unused]] = -1) const override {
30+
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
3431
}
3532
};
3633

aten/src/ATen/detail/MPSHooksInterface.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
#pragma once
44

5-
#include <c10/core/Allocator.h>
6-
#include <ATen/core/Generator.h>
75
#include <ATen/detail/AcceleratorHooksInterface.h>
6+
7+
#include <c10/core/Allocator.h>
88
#include <c10/util/Exception.h>
99
#include <c10/util/Registry.h>
1010

@@ -31,7 +31,8 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
3131
virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
3232
FAIL_MPSHOOKS_FUNC(__func__);
3333
}
34-
virtual const Generator& getDefaultMPSGenerator() const {
34+
const Generator& getDefaultGenerator(
35+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
3536
FAIL_MPSHOOKS_FUNC(__func__);
3637
}
3738
virtual Allocator* getMPSDeviceAllocator() const {

aten/src/ATen/detail/PrivateUse1HooksInterface.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#pragma once
22

3-
#include <ATen/core/Generator.h>
43
#include <ATen/detail/AcceleratorHooksInterface.h>
54
#include <c10/core/Allocator.h>
65
#include <c10/core/Device.h>
76
#include <c10/core/Storage.h>
87
#include <c10/util/Exception.h>
8+
99
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
10+
1011
namespace at {
1112

1213
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
1314
~PrivateUse1HooksInterface() override = default;
14-
virtual const at::Generator& getDefaultGenerator(
15-
c10::DeviceIndex device_index) const {
15+
16+
const at::Generator& getDefaultGenerator(
17+
c10::DeviceIndex device_index) const override {
1618
TORCH_CHECK_NOT_IMPLEMENTED(
1719
false,
1820
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");

aten/src/ATen/detail/XPUHooksInterface.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <c10/util/Exception.h>
55
#include <c10/util/Registry.h>
66

7-
#include <ATen/core/Generator.h>
87
#include <ATen/detail/AcceleratorHooksInterface.h>
98

109
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
@@ -32,17 +31,17 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
3231
TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
3332
}
3433

35-
virtual Generator getXPUGenerator(
36-
[[maybe_unused]] DeviceIndex device_index = -1) const {
37-
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
38-
}
39-
40-
virtual const Generator& getDefaultXPUGenerator(
41-
[[maybe_unused]] DeviceIndex device_index = -1) const {
34+
const Generator& getDefaultGenerator(
35+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
4236
TORCH_CHECK(
4337
false, "Cannot get default XPU generator without ATen_xpu library.");
4438
}
4539

40+
Generator getNewGenerator(
41+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
42+
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
43+
}
44+
4645
virtual DeviceIndex getNumGPUs() const {
4746
return 0;
4847
}

0 commit comments

Comments
 (0)