
Commit 32e2e05

fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent abf2eba commit 32e2e05

19 files changed: +79 additions, -67 deletions


tests/pytorch/debug/test_api_features.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # limitations under the License.
 
 import torch
-from transformer_engine.pytorch.tensor import Float8Tensor, Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 
 import nvdlfw_inspect.api as nvinspect_api
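
The change above (repeated in test_numerics.py and test_sanity.py below) only repoints the import: Float8Tensor and Float8Quantizer now come from the transformer_engine.pytorch.tensor.float8_tensor submodule rather than the transformer_engine.pytorch.tensor package. A minimal usage sketch of the new import path; the constructor argument names and the quantize() call are assumptions, not shown in this commit:

    # Sketch only; constructor arguments and quantize() are assumed, not taken from this diff.
    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer

    quantizer = Float8Quantizer(
        scale=torch.ones(1, device="cuda"),   # assumed argument name
        amax=torch.zeros(1, device="cuda"),   # assumed argument name
        fp8_dtype=tex.DType.kFloat8E4M3,      # assumed argument name
    )
    fp8_tensor = quantizer.quantize(torch.randn(128, 128, device="cuda"))  # assumed API
    assert isinstance(fp8_tensor, Float8Tensor)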

tests/pytorch/debug/test_numerics.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from transformer_engine.common.recipe import DelayedScaling, Format
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import _default_sf_compute
-from transformer_engine.pytorch.tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_FPROP,

tests/pytorch/debug/test_sanity.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from transformer_engine.common.recipe import DelayedScaling, Format
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import _default_sf_compute
-from transformer_engine.pytorch.tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 
 from test_numerics import create_config_file

tests/pytorch/distributed/test_fusible_ops.py

Lines changed: 0 additions & 4 deletions
@@ -21,12 +21,8 @@
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-<<<<<<< HEAD
-from transformer_engine.pytorch.tensor import Float8Quantizer
-=======
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
->>>>>>> origin/release_v2.0
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible

transformer_engine/debug/debug_quantization.py

Lines changed: 21 additions & 12 deletions
@@ -9,6 +9,7 @@
 
 from ..pytorch.tensor.quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc, prepare_for_saving, restore_from_saved
 from transformer_engine.debug.debug_state import TEDebugState
+import transformer_engine_torch as tex
 
 """
 This file contains DebugQuantizer and DebugQuantizedTensor objects, which are wrapper along Quantizer and QuantizedTensor
@@ -235,33 +236,40 @@ def update_quantized(
         noop_flag: Optional[torch.Tensor] = None,
     ) -> QuantizedTensor:
         assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
+        iteration = nvinspect_api.DEBUG_MANAGER._trainer_iteration_count
         updated_second_gemm = False
         updated_first_gemm = False
         if self.parent_quantizer is not None:
-            if self.first_gemm_usage and self.fp8_quantize_first_gemm:
-                dst.first_gemm.quantize_(src)
+            if dst.first_gemm_tensor is not None and self.fp8_quantize_first_gemm:
+                if hasattr(dst.first_gemm_tensor, "quantize_"):
+                    dst.first_gemm_tensor.quantize_(src, noop_flag=None)
+                else:
+                    tex.quantize(src, self.parent_quantizer, dst.first_gemm_tensor, None)
                 updated_first_gemm = True
-            elif self.second_gemm_usage and self.fp8_quantize_first_gemm:
-                dst.second_gemm.quantize_(src)
+            if dst.second_gemm_tensor is not None and self.fp8_quantize_second_gemm:
+                if hasattr(dst.second_gemm_tensor, "quantize_"):
+                    dst.second_gemm_tensor.quantize_(src, noop_flag=None)
+                else:
+                    tex.quantize(src, self.parent_quantizer, dst.second_gemm_tensor, None)
                 updated_second_gemm = True
 
         if self.process_tensor_second_gemm:
             out = nvinspect_api.transformer_engine.process_tensor(
                 layer_name=self.layer_name, tensor_name=self.tensor_name,
-                gemm=self.second_gemm_gemm_name, tensor=src,
-                default_quantizer=self.parent_quantizer, out=dst.second_gemm)
+                gemm=self.second_gemm_name, tensor=src,
+                default_quantizer=self.parent_quantizer, out=dst.second_gemm_tensor, iteration=iteration)
             assert out is None, "API call nvinspect_api.transformer_engine.process_tensor with out != None should return None"
             updated_second_gemm = True
         if self.process_tensor_first_gemm:
             nvinspect_api.transformer_engine.process_tensor(
                 layer_name=self.layer_name, tensor_name=self.tensor_name,
-                gemm=self.process_tensor_first_gemm, default_quantizer=self.parent_quantizer,
-                tensor=src, out=dst.first_gemm)
+                gemm=self.first_gemm_name, tensor=src,
+                default_quantizer=self.parent_quantizer, out=dst.first_gemm_tensor, iteration=iteration)
             updated_first_gemm = True
         if not updated_second_gemm:
-            dst.second_gemm.copy_(src)
+            dst.second_gemm_tensor.copy_(src)
         if updated_second_gemm and not updated_first_gemm:
-            dst.first_gemm.copy_(src)
+            dst.first_gemm_tensor.copy_(src)
         # if updated_first_gemm and updated_second_gemm, then
        # dst.second_gemm and dst.first_gemm. is the same tensor,
         # and it is already updated.
@@ -313,11 +321,12 @@ def restore_from_saved(self, tensors):
             restore_from_saved([self.first_gemm_tensor, self.second_gemm_tensor], tensors, return_saved_tensors=True)
         return saved_tensors
 
-    def _quantize(self, tensor):
+    def quantize_(self, tensor, *, noop_flag = None):
+        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
         self.quantizer.update_quantized(tensor, self)
 
     def dequantize(self, *, dtype = torch.float32):
-        return self.first_gemm.dequantize().to(dtype)
+        return self.first_gemm_tensor.dequantize().to(dtype)
 
     def get_tensor(self, transpose:bool):
         # Is used in the python gemm() to get tensor or transpose of the tensor.
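
The update_quantized() rewrite above dispatches on whether the destination tensor exposes an in-place quantize_() method and otherwise falls back to the raw extension call tex.quantize(src, quantizer, dst, None). A standalone sketch of that dispatch; fill_quantized() is a hypothetical helper name, while both branches are copied from the diff above:

    # Hypothetical helper illustrating the dispatch pattern used in update_quantized().
    import transformer_engine_torch as tex

    def fill_quantized(src, quantizer, dst):
        """Quantize src into dst in place, preferring the tensor's own quantize_ method."""
        if hasattr(dst, "quantize_"):
            # Quantized tensor classes expose in-place quantization themselves.
            dst.quantize_(src, noop_flag=None)
        else:
            # Otherwise call the C++ extension directly, as the diff does.
            tex.quantize(src, quantizer, dst, None)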

transformer_engine/debug/debug_state.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def initialize(cls):
     def reset(cls):
         from .features.utils.stats_buffer import STATS_BUFFERS, StatsBuffers
         STATS_BUFFERS.reset()
+        cls.debug_enabled = None
         cls.layers_initialized.clear()
 
     @classmethod

transformer_engine/debug/features/api.py

Lines changed: 5 additions & 4 deletions
@@ -100,7 +100,7 @@ def use_process_tensor(self, *args, **kwargs):
         return False
 
     def process_tensor(self, *args, **kwargs):
-        return kwargs["tensor"]
+        raise RuntimeError("use_process_tensor() returned True, process_tensor() was invoked, but it is not handled by any API.")
 
     def look_at_tensor_before_process(self, *args, **kwargs):
         pass
@@ -167,9 +167,10 @@ def routing_condition(self, api_name, config, layer_name, feature_obj, **kwargs)
         return status, modified_config
 
     def output_assertions_hook(self, api_name, ret, **kwargs):
-        if api_name in {"process_tensor"}:
-            assert type(ret) in [torch.Tensor, Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], \
-                f"This API {api_name} must return a tensor."
+        pass
+        #if api_name in {"process_tensor"}:
+        #    assert type(ret) in [torch.Tensor, Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], \
+        #        f"This API {api_name} must return a tensor."
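
With this change the default process_tensor() hook raises instead of silently returning the input, so any feature whose use_process_tensor() answers True must implement process_tensor() itself. A hedged sketch of such a feature: the class name is hypothetical, the @api_method decorator and the config/layer_name parameters follow DisableFp8Layer below, and the remaining keyword names (tensor, gemm, default_quantizer, out, iteration) are taken from the call site in debug_quantization.py above:

    # Assumed feature sketch, not part of this commit; @api_method import path not shown here.
    class ClampTensor:
        @api_method
        def use_process_tensor(self, config, layer_name, *args, **kwargs):
            return True  # opt in, so process_tensor() below will be invoked

        @api_method
        def process_tensor(self, config, layer_name, *, tensor, gemm,
                           default_quantizer=None, out=None, iteration=None, **kwargs):
            # Fill `out` in place and return None, matching the assert in update_quantized()
            # ("process_tensor with out != None should return None").
            out.copy_(tensor.clamp(-1.0, 1.0))
            return None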

transformer_engine/debug/features/disable_fp8_layer.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ class DisableFp8Layer:
     """
 
     @api_method
-    def fp8_gemm(self, config, layer_name, gemm):
+    def fp8_gemm(self, config, layer_name, *args, **kwargs):
         for key in config:
             if key not in ["enabled", "gemm"]:
                 raise ValueError(f"[NVTORCH INSPECT ERROR] Unexpected key in config: \"{key}\".")

transformer_engine/debug/features/log_tensor_stats.py

Lines changed: 3 additions & 2 deletions
@@ -66,9 +66,10 @@ def look_at_tensor_before_process(self, config, layer_name,
         options = (config.get('start_step', None), config.get('end_step', None), config.get('start_end_list', None),)
         skip_reduction = False
         reduction_group = nvinspect_api.get_tensor_reduction_group()
-        if self.tensor_name == "weight":
+        if tensor_name == "weight":
             if TEDebugState.weight_tensor_tp_group_reduce:
-                reduction_group = self.tp_group
+                pass
+                #reduction_group = self.tp_group
             else:
                 skip_reduction = True

transformer_engine/debug/features/utils/stats_buffer.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def log(self):
             combiner = STATS[stat_name][1]
             stat_value = combiner(gathered_helper_stats)
 
-            MetricLogger.log_scalar(f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value.float(), self.iteration)
+            MetricLogger.log_scalar(f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration)
             output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = stat_value # for debuggin purpouses
         self._reset_before_next_step()
         return output
