@@ -1,4 +1,5 @@
 import functools
+import gc
 import importlib
 import importlib.metadata
 import inspect
@@ -86,7 +87,12 @@
         ) from e
     logger.info(f"torch_device overridden to {torch_device}")
 else:
-    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.cuda.is_available():
+        torch_device = "cuda"
+    elif torch.xpu.is_available():
+        torch_device = "xpu"
+    else:
+        torch_device = "cpu"
 is_torch_higher_equal_than_1_12 = version.parse(
     version.parse(torch.__version__).base_version
 ) >= version.parse("1.12")
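
Note: `torch.xpu` only ships with recent PyTorch builds, so on older versions the new `elif` branch would raise `AttributeError` at import time. A defensive variant of the device selection (a sketch, not part of this patch) could guard the attribute lookup first:

    import torch

    if torch.cuda.is_available():
        torch_device = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch_device = "xpu"
    else:
        torch_device = "cpu"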
@@ -1055,12 +1061,34 @@ def _is_torch_fp64_available(device):
 # Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
 if is_torch_available():
     # Behaviour flags
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
 
     # Function definitions
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": torch.xpu.reset_peak_memory_stats,
+        "default": None,
+    }
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
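
For context, `_device_agnostic_dispatch` is not part of this hunk. A minimal sketch consistent with the tables above (the real definition lives elsewhere in the file and may differ) would be:

    from typing import Any, Dict

    def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Any], *args, **kwargs):
        # Fall back to the "default" entry for devices without a specific override.
        fn = dispatch_table.get(device, dispatch_table["default"])
        # Some entries are plain values (e.g. 0 or None) rather than callables.
        if not callable(fn):
            return fn
        return fn(*args, **kwargs)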
@@ -1091,6 +1119,10 @@ def backend_device_count(device: str):
     return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
 
 
+def backend_reset_peak_memory(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
@@ -1147,3 +1179,13 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
         update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
         update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
         update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+        update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEM_STATS")
+
+
+@require_torch
+def flush_memory(device: str, gc_collect=False, reset_mem_stats=False):
+    if gc_collect:
+        gc.collect()
+    if reset_mem_stats:
+        backend_reset_peak_memory(device)
+    backend_empty_cache(device)
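
Taken together, a test can now clear accelerator state in a device-agnostic way before taking memory measurements. A hypothetical usage sketch follows; the import path is an assumption based on where these utilities typically live:

    # Assumed import path for illustration; adjust to the actual testing-utils module.
    from diffusers.utils.testing_utils import flush_memory, torch_device

    # Drop dead Python references, reset the device's peak-memory counters, and
    # empty the accelerator cache so the next measurement starts from a clean slate.
    flush_memory(torch_device, gc_collect=True, reset_mem_stats=True)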