|
9 | 9 |
|
10 | 10 | from ..pytorch.tensor.quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc, prepare_for_saving, restore_from_saved |
11 | 11 | from transformer_engine.debug.debug_state import TEDebugState |
| 12 | +import transformer_engine_torch as tex |
12 | 13 |
|
13 | 14 | """ |
14 | 15 | This file contains DebugQuantizer and DebugQuantizedTensor objects, which are wrappers around Quantizer and QuantizedTensor
@@ -235,33 +236,40 @@ def update_quantized( |
235 | 236 | noop_flag: Optional[torch.Tensor] = None, |
236 | 237 | ) -> QuantizedTensor: |
237 | 238 | assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" |
| 239 | + iteration = nvinspect_api.DEBUG_MANAGER._trainer_iteration_count |
238 | 240 | updated_second_gemm = False |
239 | 241 | updated_first_gemm = False |
240 | 242 | if self.parent_quantizer is not None: |
241 | | - if self.first_gemm_usage and self.fp8_quantize_first_gemm: |
242 | | - dst.first_gemm.quantize_(src) |
| 243 | + if dst.first_gemm_tensor is not None and self.fp8_quantize_first_gemm: |
| 244 | + if hasattr(dst.first_gemm_tensor, "quantize_"): |
| 245 | + dst.first_gemm_tensor.quantize_(src, noop_flag=None) |
| 246 | + else: |
| 247 | + tex.quantize(src, self.parent_quantizer, dst.first_gemm_tensor, None) |
243 | 248 | updated_first_gemm = True |
244 | | - elif self.second_gemm_usage and self.fp8_quantize_first_gemm: |
245 | | - dst.second_gemm.quantize_(src) |
| 249 | + if dst.second_gemm_tensor is not None and self.fp8_quantize_second_gemm: |
| 250 | + if hasattr(dst.second_gemm_tensor, "quantize_"): |
| 251 | + dst.second_gemm_tensor.quantize_(src, noop_flag=None) |
| 252 | + else: |
| 253 | + tex.quantize(src, self.parent_quantizer, dst.second_gemm_tensor, None) |
246 | 254 | updated_second_gemm = True |
247 | 255 |
|
248 | 256 | if self.process_tensor_second_gemm: |
249 | 257 | out = nvinspect_api.transformer_engine.process_tensor( |
250 | 258 | layer_name=self.layer_name, tensor_name=self.tensor_name, |
251 | | - gemm=self.second_gemm_gemm_name, tensor=src, |
252 | | - default_quantizer=self.parent_quantizer, out=dst.second_gemm) |
| 259 | + gemm=self.second_gemm_name, tensor=src, |
| 260 | + default_quantizer=self.parent_quantizer, out=dst.second_gemm_tensor, iteration=iteration) |
253 | 261 | assert out is None, "API call nvinspect_api.transformer_engine.process_tensor with out != None should return None" |
254 | 262 | updated_second_gemm = True |
255 | 263 | if self.process_tensor_first_gemm: |
256 | 264 | nvinspect_api.transformer_engine.process_tensor( |
257 | 265 | layer_name=self.layer_name, tensor_name=self.tensor_name, |
258 | | - gemm=self.process_tensor_first_gemm, default_quantizer=self.parent_quantizer, |
259 | | - tensor=src, out=dst.first_gemm) |
| 266 | + gemm=self.first_gemm_name, tensor=src, |
| 267 | + default_quantizer=self.parent_quantizer, out=dst.first_gemm_tensor, iteration=iteration) |
260 | 268 | updated_first_gemm = True |
261 | 269 | if not updated_second_gemm: |
262 | | - dst.second_gemm.copy_(src) |
| 270 | + dst.second_gemm_tensor.copy_(src) |
263 | 271 | if updated_second_gemm and not updated_first_gemm: |
264 | | - dst.first_gemm.copy_(src) |
| 272 | + dst.first_gemm_tensor.copy_(src) |
265 | 273 | # if updated_first_gemm and updated_second_gemm, then
266 | 274 | # dst.second_gemm_tensor and dst.first_gemm_tensor are the same tensor,
267 | 275 | # and it is already updated.
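
The updated path dispatches on the destination's capabilities: a destination that exposes an in-place quantize_ method is updated through it, while a plain tensor falls back to the functional tex.quantize call; any destination left untouched by quantization or a process_tensor hook receives a plain copy_. Below is a minimal, self-contained sketch of that dispatch, using hypothetical stand-ins (FakeQuantizedTensor, functional_quantize) rather than the real transformer_engine classes:

```python
import torch

class FakeQuantizedTensor:
    """Stand-in for a destination that knows how to quantize in place."""

    def __init__(self):
        self.data = None

    def quantize_(self, src, *, noop_flag=None):
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.data = src.clone()  # the real class would quantize here

def functional_quantize(src, quantizer, dst):
    """Stand-in for the tex.quantize(src, quantizer, dst, None) fallback."""
    dst.copy_(src)

def update_quantized(dst, src, quantizer=None):
    # Same dispatch as in the diff: prefer the in-place method, else fall back.
    if hasattr(dst, "quantize_"):
        dst.quantize_(src, noop_flag=None)
    else:
        functional_quantize(src, quantizer, dst)

update_quantized(FakeQuantizedTensor(), torch.randn(4))  # method path
update_quantized(torch.empty(4), torch.randn(4))         # functional fallback
```
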
@@ -313,11 +321,12 @@ def restore_from_saved(self, tensors): |
313 | 321 | restore_from_saved([self.first_gemm_tensor, self.second_gemm_tensor], tensors, return_saved_tensors=True) |
314 | 322 | return saved_tensors |
315 | 323 |
|
316 | | - def _quantize(self, tensor): |
| 324 | + def quantize_(self, tensor, *, noop_flag=None):
| 325 | + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" |
317 | 326 | self.quantizer.update_quantized(tensor, self) |
318 | 327 |
|
319 | 328 | def dequantize(self, *, dtype = torch.float32): |
320 | | - return self.first_gemm.dequantize().to(dtype) |
| 329 | + return self.first_gemm_tensor.dequantize().to(dtype) |
321 | 330 |
|
322 | 331 | def get_tensor(self, transpose:bool): |
323 | 332 | # Used by the Python gemm() to fetch either the tensor or its transpose.
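
The second hunk renames the private _quantize helper to a standard quantize_ method that rejects CUDA-Graph noop flags, and makes dequantize read through first_gemm_tensor. A minimal sketch of that delegation, assuming hypothetical SketchQuantizer and SketchDebugTensor stand-ins for the real classes:

```python
import torch

class SketchQuantizer:
    """Stand-in for DebugQuantizer; only update_quantized is modeled."""

    def update_quantized(self, src, dst):
        dst.first_gemm_tensor = src.clone()  # the real code quantizes into dst

class SketchDebugTensor:
    """Stand-in for DebugQuantizedTensor's quantize_/dequantize pair."""

    def __init__(self, quantizer):
        self.quantizer = quantizer
        self.first_gemm_tensor = None

    def quantize_(self, tensor, *, noop_flag=None):
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.quantizer.update_quantized(tensor, self)

    def dequantize(self, *, dtype=torch.float32):
        return self.first_gemm_tensor.to(dtype)

t = SketchDebugTensor(SketchQuantizer())
t.quantize_(torch.randn(3, dtype=torch.bfloat16))
print(t.dequantize().dtype)  # torch.float32: the cast is applied on the way out
```
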
|