|
9 | 9 | import _torch_ipex as core |
10 | 10 |
|
11 | 11 | DEVICE = 'dpcpp' |
12 | | -def enable_auto_optimization(mixed_dtype = None, train = False, configure_file = None): |
13 | | - r""" Enable auto-mixed-precision to improve performance. |
14 | 12 |
|
15 | | - The auto-mixed-precision auto reorders the tensor to the specified low precision data type. |
16 | | - You don't need to convert the input tensors and the model to the specified data type manually, |
17 | | - the extension will do it automatically and then dispatch the extension backend to accelerate |
18 | | - computation |
19 | | -
|
20 | | - Args: |
21 | | - mixed_dtype(torch.dtype): Auto reorder the input tensors to the specified low precision data type |
22 | | - and dispatch to oneDNN backend for computation |
23 | | -
|
24 | | - """ |
25 | | - if mixed_dtype != None: |
26 | | - core.enable_auto_dnnl() |
27 | | - enable_auto_mix_precision(mixed_dtype, train, configure_file) |
28 | | - |
29 | | -def get_auto_optimization(): |
30 | | - return get_auto_mix_precision |
31 | | - |
32 | | -def get_train(): |
33 | | - return core.get_train() |
34 | | - |
35 | | -def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train = False, configure_file = None): |
36 | | - if mixed_dtype == torch.bfloat16: |
37 | | - core.enable_mix_bf16_fp32() |
38 | | - core.disable_mix_int8_fp32() |
39 | | - elif mixed_dtype == torch.int8 or mixed_dtype == torch.uint8: |
40 | | - core.enable_mix_int8_fp32() |
41 | | - core.disable_mix_bf16_fp32() |
42 | | - if configure_file != None: |
43 | | - core.disable_int8_calibration() |
44 | | - f = open(configure_file) |
45 | | - configures = json.load(f) |
46 | | - core.load_indicators_file(configures) |
47 | | - else: |
48 | | - warnings.warn("please not forget do calibration before doing validation step") |
49 | | - else: |
50 | | - core.disable_mix_int8_fp32() |
51 | | - core.disable_mix_bf16_fp32() |
52 | | - core.set_execution_mode(train=train) |
53 | | - |
54 | | -def get_auto_mix_precision(): |
55 | | - if core.get_mix_bf16_fp32(): |
56 | | - return torch.bfloat16 |
57 | | - elif core.get_mix_int8_fp32(): |
58 | | - return torch.int8 |
59 | | - else: |
60 | | - return None |
61 | | - |
62 | | -''' |
63 | | -def quarry_int8_configure(model, inputs_shape): |
64 | | - dummy_input = torch.randn(input_shapes).to(DEVICE) |
65 | | - core.enable_mix_int8_fp32() |
66 | | - with torch.no_grad(): |
67 | | - y = model(dummy_input) |
68 | | - observer_configures = core.get_int8_observer_configures() |
69 | | - return observer_configures |
70 | | -''' |
71 | | - |
72 | | -def calibration_reset(): |
73 | | - if core.get_int8_calibration(): |
74 | | - core.calibration_reset() |
75 | | - else: |
76 | | - raise ValueError("please first run enable_calibration before calibration reset") |
| 13 | +class AmpConf(object): |
| 14 | + def __init__(self, mixed_dtype = torch.bfloat16, configure_file = None): |
| 15 | + self.dtype = mixed_dtype |
| 16 | + self.configure_file = configure_file |
| 17 | + |
| 18 | + # for the int8 path, if the user provides an existing configure file, load it. |
| 19 | + if self.configure_file != None and self.dtype != torch.bfloat16: |
| 20 | + if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0: |
| 21 | + with open(self.configure_file, 'r') as f: |
| 22 | + configures = json.load(f) |
| 23 | + core.load_indicators_file(configures) |
| 24 | + else: |
| 25 | + assert False, 'Cannot load an empty or non-existent file, please run the calibration step first' |
| 26 | + |
| 27 | + # for int8 quantization, save the data gathered during the calibration step. |
| 28 | + def save(self, configure_file): |
| 29 | + core.add_indicators() |
| 30 | + configures = core.get_int8_configures() |
| 31 | + with open(configure_file, 'w') as fp: |
| 32 | + json.dump(configures, fp, indent = 4) |
77 | 33 |
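
The new `AmpConf` object carries the mixed dtype plus an optional int8 configure file, and `save()` persists the indicators collected during calibration. A minimal sketch of how it might be used, assuming the extension is importable as `ipex` (the package alias and the json file name below are placeholders, not part of this patch):

```python
# Minimal sketch; the import alias and json path are assumptions.
import torch
import intel_pytorch_extension as ipex  # assumed package name

# bfloat16 path: no configure file is needed.
bf16_conf = ipex.AmpConf(torch.bfloat16)

# int8 path: reuse a configure file written by an earlier calibration run;
# AmpConf loads the saved indicators from it at construction time.
int8_conf = ipex.AmpConf(torch.int8, configure_file='resnet50_int8.json')

# After running calibration, persist the gathered indicators for later runs.
int8_conf.save('resnet50_int8.json')
```
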
|
78 | 34 | class _DecoratorContextManager: |
79 | 35 | """Allow a context manager to be used as a decorator, copied from the PyTorch framework""" |
@@ -102,22 +58,80 @@ def generator_context(*args, **kwargs): |
102 | 58 | break |
103 | 59 | return generator_context |
104 | 60 |
|
105 | | -class int8_calibration(_DecoratorContextManager): |
106 | | - def __init__(self, file_name, observer_configure=None): |
107 | | - #self.observer_configure = observer_configure |
108 | | - self.configure_file = file_name |
| 61 | +def get_auto_mix_precision(): |
| 62 | + if core.get_mix_bf16_fp32(): |
| 63 | + return torch.bfloat16 |
| 64 | + elif core.get_mix_int8_fp32(): |
| 65 | + return torch.int8 |
| 66 | + else: |
| 67 | + return None |
| 68 | + |
| 69 | +def enable_auto_optimization(mixed_dtype = None, train = False): |
| 70 | + r""" Enable auto-mixed-precision to improve performance at global scope. |
| 71 | +
|
| 72 | + Auto-mixed-precision automatically reorders tensors to the specified low-precision data type. |
| 73 | + You don't need to convert the input tensors and the model to the specified data type manually; |
| 74 | + the extension does it automatically and then dispatches to the extension backend to accelerate |
| 75 | + the computation. |
| 76 | +
|
| 77 | + Args: |
| 78 | + mixed_dtype(torch.dtype): Automatically reorder the input tensors to the specified low-precision data type |
| 79 | + and dispatch to the oneDNN backend for computation; can be torch.bfloat16 or None. |
| 80 | + """ |
| 81 | + if mixed_dtype != None: |
| 82 | + core.enable_auto_dnnl() |
| 83 | + running_mode = 'training' if train else 'inference' |
| 84 | + enable_auto_mix_precision(AmpConf(mixed_dtype), running_mode).__enter__() |
| 85 | + |
| 86 | +def get_auto_optimization(): |
| 87 | + return get_auto_mix_precision |
| 88 | + |
| 89 | +def get_train(): |
| 90 | + return core.get_train() |
| 91 | + |
| 92 | +class enable_auto_mix_precision(_DecoratorContextManager): |
| 93 | + def __init__(self, conf, running_mode = 'inference'): |
| 94 | + self.pre_mixed_dtype = get_auto_mix_precision() |
| 95 | + self.pre_running_mode = get_train() |
| 96 | + self.pre_calibration_state = core.get_int8_calibration() |
| 97 | + self.mixed_dtype = conf.dtype |
| 98 | + self.running_mode = running_mode |
109 | 99 |
|
110 | 100 | def __enter__(self): |
111 | | - if not core.get_mix_int8_fp32(): |
112 | | - raise ValueError("please first run enable_auto_mix_precision(torch.int8) before int8 calibration") |
113 | | - core.enable_int8_calibration() |
114 | | - #core.set_int8_observer_configure(self.observer_configure) |
| 101 | + if self.mixed_dtype == torch.bfloat16: |
| 102 | + core.enable_mix_bf16_fp32() |
| 103 | + core.disable_mix_int8_fp32() |
| 104 | + elif self.mixed_dtype == torch.int8: |
| 105 | + core.enable_mix_int8_fp32() |
| 106 | + core.disable_mix_bf16_fp32() |
| 107 | + if self.running_mode == 'inference': |
| 108 | + core.disable_int8_calibration() |
| 109 | + elif self.running_mode == 'calibration': |
| 110 | + core.enable_int8_calibration() |
| 111 | + else: |
| 112 | + assert False, 'int8 quantization only supports inference and calibration running modes' |
| 113 | + else: |
| 114 | + core.disable_mix_int8_fp32() |
| 115 | + core.disable_mix_bf16_fp32() |
| 116 | + core.set_execution_mode(train = True if self.running_mode == 'training' else False) |
115 | 117 |
|
116 | 118 | def __exit__(self, *args): |
117 | | - core.disable_int8_calibration() |
118 | | - core.add_indicators() |
119 | | - configures = core.get_int8_configures() |
120 | | - with open(self.configure_file, 'w') as fp: |
121 | | - json.dump(configures, fp, indent=4) |
122 | | - return False |
| 119 | + if self.mixed_dtype == torch.int8: |
| 120 | + if self.running_mode == 'calibration': |
| 121 | + core.calibration_reset() |
| 122 | + # restore previous state |
| 123 | + if self.pre_calibration_state: |
| 124 | + core.enable_int8_calibration() |
| 125 | + else: |
| 126 | + core.disable_int8_calibration() |
| 127 | + if self.pre_mixed_dtype == torch.bfloat16: |
| 128 | + core.enable_mix_bf16_fp32() |
| 129 | + core.disable_mix_int8_fp32() |
| 130 | + elif self.pre_mixed_dtype == torch.int8: |
| 131 | + core.enable_mix_int8_fp32() |
| 132 | + core.disable_mix_bf16_fp32() |
| 133 | + else: |
| 134 | + core.disable_mix_int8_fp32() |
| 135 | + core.disable_mix_bf16_fp32() |
| 136 | + core.set_execution_mode(train = self.pre_running_mode) |
123 | 137 |
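
Taken together, the context-manager API replaces the old `int8_calibration` helper: calibration statistics are collected inside `enable_auto_mix_precision(conf, running_mode='calibration')`, saved via `AmpConf.save`, and reloaded for inference. A rough usage sketch, again assuming the `ipex` import alias, with `model`, `calib_loader`, and the json path as placeholders:

```python
# Rough sketch; `ipex`, `model`, `calib_loader`, and the json path are
# placeholders, not part of this patch.
import torch
import intel_pytorch_extension as ipex  # assumed package name

model = model.to('dpcpp')  # DEVICE used by the extension

# Option 1: global bfloat16 auto-mixed-precision for the whole process.
ipex.enable_auto_optimization(mixed_dtype=torch.bfloat16)

# Option 2: scoped int8 quantization.
# Calibration: collect statistics batch by batch, then save them.
conf = ipex.AmpConf(torch.int8)
with torch.no_grad():
    for x in calib_loader:
        with ipex.enable_auto_mix_precision(conf, running_mode='calibration'):
            model(x.to('dpcpp'))
conf.save('configure.json')

# Inference: AmpConf reloads the saved indicators, and the context manager
# restores the previous mixed-precision state on exit.
conf = ipex.AmpConf(torch.int8, configure_file='configure.json')
with torch.no_grad():
    with ipex.enable_auto_mix_precision(conf, running_mode='inference'):
        out = model(x.to('dpcpp'))
```
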
|