|
108 | 108 | for library in LOADABLE_CLASSES: |
109 | 109 | LIBRARIES.append(library) |
110 | 110 |
|
111 | | -SUPPORTED_DEVICE_MAP = ["balanced"] |
| 111 | +# TODO: support single-device namings |
| 112 | +SUPPORTED_DEVICE_MAP = ["balanced", "cuda"] |
112 | 113 |
|
113 | 114 | logger = logging.get_logger(__name__) |
114 | 115 |
|
@@ -988,12 +989,15 @@ def load_module(name, value): |
988 | 989 | _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config) |
989 | 990 | for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): |
990 | 991 | # 7.1 device_map shenanigans |
991 | | - if final_device_map is not None and len(final_device_map) > 0: |
992 | | - component_device = final_device_map.get(name, None) |
993 | | - if component_device is not None: |
994 | | - current_device_map = {"": component_device} |
995 | | - else: |
996 | | - current_device_map = None |
| 992 | + if final_device_map is not None: |
| 993 | + if isinstance(final_device_map, dict) and len(final_device_map) > 0: |
| 994 | + component_device = final_device_map.get(name, None) |
| 995 | + if component_device is not None: |
| 996 | + current_device_map = {"": component_device} |
| 997 | + else: |
| 998 | + current_device_map = None |
| 999 | + elif isinstance(final_device_map, str): |
| 1000 | + current_device_map = final_device_map |
997 | 1001 |
|
998 | 1002 | # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names |
999 | 1003 | class_name = class_name[4:] if class_name.startswith("Flax") else class_name |
|
0 commit comments