-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdevice_utils.py
More file actions
150 lines (113 loc) · 3.97 KB
/
device_utils.py
File metadata and controls
150 lines (113 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import random
from typing import Union
import torch
try:
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend as xb
except ImportError:
xm = None
xr = None
xb = None
def get_current_device() -> torch.device:
    """Return the process-wide default device, resolving it on first use.

    Resolution order: XLA (when torch_xla imported successfully), then
    CUDA (pinned to the LOCAL_RANK env var), then the DEFAULT_DEVICE env
    var, falling back to CPU. The result is cached in a module-level
    variable so every later call returns the same device object.
    """
    global __current_device
    if "__current_device" not in globals():
        if xm is not None:
            __current_device = xm.xla_device()
        elif torch.cuda.is_available():
            rank = int(os.getenv("LOCAL_RANK", 0))
            __current_device = torch.device(f'cuda:{rank}')
            # Pin this process to its device so bare .cuda() calls land here.
            torch.cuda.set_device(__current_device)
        else:
            __current_device = torch.device(os.getenv("DEFAULT_DEVICE", "cpu"))
    return __current_device
def get_current_device_type() -> str:
    """Return the default device type string, cached after the first call.

    "xla" when torch_xla is importable, "cuda" when CUDA is available,
    otherwise the DEFAULT_DEVICE_TYPE env var (default "cpu").
    """
    global __current_device_type
    if "__current_device_type" not in globals():
        if xm is not None:
            kind = "xla"
        elif torch.cuda.is_available():
            kind = "cuda"
        else:
            kind = os.getenv("DEFAULT_DEVICE_TYPE", "cpu")
        __current_device_type = kind
    return __current_device_type
def get_local_device_count() -> int:
    """Return the number of accelerator devices available, defaulting to 1.

    NOTE(review): the XLA branch returns xr.global_device_count(), which on
    a multi-host setup is the *global* device count despite this function's
    "local" name — confirm whether xr.local_device_count() was intended.
    """
    if xr is not None:
        return xr.global_device_count()
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 1
def get_distributed_backend(backend=None) -> str:
    """Choose the torch.distributed backend for the current platform.

    XLA always forces "xla". Otherwise a caller-supplied *backend* wins;
    with no explicit choice, CUDA machines get "nccl" and everything else
    gets "gloo".
    """
    if xm is not None:
        return "xla"
    if backend is not None:
        return backend
    return "nccl" if torch.cuda.is_available() else "gloo"
def get_distributed_init_method() -> str:
    """Return the init_method URL for torch.distributed.init_process_group:
    the XLA rendezvous scheme under torch_xla, env-variable based otherwise."""
    return 'xla://' if xm is not None else "env://"
def get_current_rng_state() -> Union[torch.Tensor, int]:
    """Capture the RNG state of the current device.

    Returns:
        A ByteTensor for the CUDA/CPU generators; per the return annotation,
        the XLA path presumably yields an int (it is restored with an int in
        set_current_rng_state) — TODO confirm against xm.get_rng_state docs.

    Note: CUDA is deliberately checked before XLA here, preserving the
    original branch order. Fix applied: the XLA guard now uses
    ``xm is not None``, matching every other xm check in this module
    (the old ``elif xm:`` truthiness test was inconsistent).
    """
    if torch.cuda.is_available():
        rng_state = torch.cuda.get_rng_state(device=get_current_device())
    elif xm is not None:
        rng_state = xm.get_rng_state(device=get_current_device())
    else:
        rng_state = torch.get_rng_state()
    return rng_state
def set_manual_seed(seed: int):
    """Seed Python's ``random`` module and the torch RNGs for reproducibility.

    Fix: the original seeded only the CUDA generator on CUDA machines
    (``torch.cuda.manual_seed``), leaving the CPU torch generator unseeded,
    so CPU-side torch ops stayed non-deterministic. ``torch.manual_seed``
    seeds the CPU generator and every visible CUDA device, so it is now
    called unconditionally; the XLA device generator is seeded as before.

    Args:
        seed: the seed value applied to all generators.
    """
    random.seed(seed)
    if xm is not None:
        # Preserve the original XLA behavior: seed the XLA device RNG.
        xm.set_rng_state(seed, device=get_current_device())
    # Seeds the CPU RNG and all CUDA devices (superset of cuda.manual_seed).
    torch.manual_seed(seed)
def set_current_rng_state(new_state):
    """Restore an RNG state previously captured by ``get_current_rng_state``.

    The CUDA and CPU paths coerce *new_state* to a ByteTensor before
    restoring; the XLA path expects a plain int seed value.
    """
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(new_state.type(torch.ByteTensor),
                                 device=get_current_device())
    elif xm is not None:
        xm.set_rng_state(int(new_state), device=get_current_device())
    else:
        torch.set_rng_state(new_state.type(torch.ByteTensor))
# Monkey-patch torch_xla's ProcessGroupXla with channel-id factory methods
# (only when the xla_backend module imported successfully).
if xb:
    def make_send_channel_id_impl(self, dst_rank, tag):
        # NOTE(review): send channels are keyed as dst_rank*2 while recv
        # channels below use src_rank*3 — the asymmetry looks deliberate but
        # is not explained here; confirm against the torch_xla backend docs.
        # The `tag` argument is ignored.
        return int(dst_rank)*2
    def make_recv_channel_id_impl(self, src_rank, tag):
        # See the send-side note above; the recv side scales by 3 and also
        # ignores `tag`.
        return int(src_rank)*3
    xb.ProcessGroupXla.make_send_channel_id = make_send_channel_id_impl
    xb.ProcessGroupXla.make_recv_channel_id = make_recv_channel_id_impl
def parse_dtype(dtype: str):
    """Parse a legacy tensor-type string (e.g. "torch.FloatTensor") into a
    ``(device, dtype)`` pair.

    Args:
        dtype: one of the namespaces "torch", "torch.cuda", or "torch.xla"
            joined with "FloatTensor", "HalfTensor", or "BFloat16Tensor".

    Returns:
        ``(torch.device, torch.dtype)`` — CPU for the plain "torch"
        namespace, otherwise the current accelerator device.

    Raises:
        ValueError: for a malformed or unsupported type string. (The
            original validated with ``assert``, which is silently stripped
            under ``python -O``.)
    """
    namespace, _, tensor_type = dtype.rpartition(".")
    if namespace not in ('torch', 'torch.cuda', 'torch.xla'):
        raise ValueError(f"unsupported tensor namespace in {dtype!r}")
    dtypes = {
        'FloatTensor': torch.float32,
        'HalfTensor': torch.float16,
        'BFloat16Tensor': torch.bfloat16,
    }
    if tensor_type not in dtypes:
        raise ValueError(f"unsupported tensor type in {dtype!r}")
    device = torch.device("cpu") if namespace == "torch" else get_current_device()
    return device, dtypes[tensor_type]
def is_bf16_supported():
    """Report whether bfloat16 is usable on the current device type.

    XLA: driven by the XLA_USE_BF16 environment variable; CUDA: queried
    from torch; any other device type: False.
    """
    device_type = get_current_device_type()
    if device_type == 'xla':
        return bool(int(os.getenv('XLA_USE_BF16', 0)))
    if device_type == 'cuda':
        return torch.cuda.is_bf16_supported()
    return False