Skip to content

Commit b6e0767

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a54a743 commit b6e0767

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

transformer_engine/debug/features/dump_tensors.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def initialize(self, root_log_dir: str):
4848
if dist.is_initialized():
4949
self.rank = dist.get_rank()
5050

51-
self.root_dir = os.path.join(
52-
root_log_dir, "tensor_dumps", f"rank_{self.rank}"
53-
)
51+
self.root_dir = os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}")
5452
os.makedirs(self.root_dir, exist_ok=True)
5553

5654
debug_api.log_message(
@@ -75,8 +73,7 @@ def save_tensor(
7573
"""Save a tensor (or dict of tensors) to a file."""
7674
if self.root_dir is None:
7775
raise RuntimeError(
78-
"[TE DumpTensors] TensorLogger not initialized. "
79-
"Call initialize() first."
76+
"[TE DumpTensors] TensorLogger not initialized. Call initialize() first."
8077
)
8178

8279
safe_layer_name = self._sanitize_name(layer_name)
@@ -324,7 +321,9 @@ def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch
324321
if tensor._rowwise_scale_inv is not None:
325322
result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu)
326323
if tensor._columnwise_scale_inv is not None:
327-
result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e8m0fnu)
324+
result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(
325+
torch.float8_e8m0fnu
326+
)
328327

329328
return result
330329

@@ -343,7 +342,9 @@ def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch
343342
if tensor._rowwise_scale_inv is not None:
344343
result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn)
345344
if tensor._columnwise_scale_inv is not None:
346-
result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e4m3fn)
345+
result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(
346+
torch.float8_e4m3fn
347+
)
347348

348349
# Input absolute maximum value (used to compute tensor scale)
349350
if tensor._amax_rowwise is not None:

0 commit comments

Comments
 (0)