
Commit acdac2c

FindHao authored and facebook-github-bot committed
Enhance structured logging initialization and test coverage (#32)
Summary:

- Updated `tritonparse.structured_logging.init` to accept an `enable_trace_launch` parameter, allowing launch events to be traced alongside compilation events.
- Modified test files to use the new logging feature and verify event type counts in generated log files, ensuring accurate tracking of 'launch' and 'compilation' events.
- Added checks in `test_tritonparse.py` to assert the expected counts of log events, improving test robustness.

These changes improve the logging functionality and ensure comprehensive testing of the logging behavior.

Pull Request resolved: #32

Reviewed By: davidberard98

Differential Revision: D78310275

Pulled By: FindHao

fbshipit-source-id: a9147a5b83d2aae21a555cd8651b5a018deb87ef
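
For quick orientation, this is the initialization pattern the updated tests now follow (a minimal sketch assembled from the diffs below; the log path, output path, and the kernel work in between are placeholders):

```python
import tritonparse.structured_logging
import tritonparse.utils

log_path = "./logs"
# New in this commit: enable_trace_launch records kernel launch events
# in addition to compilation events (the flag defaults to False).
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

# ... compile and launch Triton kernels here; each compilation and each
# launch is appended to the NDJSON trace files under log_path ...

# Parse the collected traces into a separate output directory.
tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)
```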
1 parent 639b1c5 commit acdac2c

File tree (3 files changed: +123, -21 lines)

tests/test_add.py
tests/test_tritonparse.py
tritonparse/structured_logging.py

tests/test_add.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,8 @@
 ```
 """
 
+import os
+
 import torch
 import triton
 import triton.language as tl
@@ -15,7 +17,9 @@
 import tritonparse.utils
 
 log_path = "./logs"
-tritonparse.structured_logging.init(log_path)
+tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
+
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
 
 
 @triton.jit

tests/test_tritonparse.py

Lines changed: 114 additions & 19 deletions
@@ -6,6 +6,7 @@
 ```
 """
 
+import json
 import os
 import shutil
 import tempfile
@@ -170,12 +171,7 @@ def test_whole_workflow(self):
 
         # Define a simple kernel directly in the test function
         @triton.jit
-        def test_kernel(
-            x_ptr,
-            y_ptr,
-            n_elements,
-            BLOCK_SIZE: tl.constexpr,
-        ):
+        def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
             pid = tl.program_id(axis=0)
             block_start = pid * BLOCK_SIZE
             offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -189,48 +185,147 @@ def test_kernel(
         def run_test_kernel(x):
             n_elements = x.numel()
             y = torch.empty_like(x)
-            BLOCK_SIZE = 256  # Smaller block size for simplicity
+            BLOCK_SIZE = 256
             grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
             test_kernel[grid](x, y, n_elements, BLOCK_SIZE)
             return y
 
+        # Set up test environment
         temp_dir = tempfile.mkdtemp()
-        print(f"Temporary directory: {temp_dir}")
         temp_dir_logs = os.path.join(temp_dir, "logs")
-        os.makedirs(temp_dir_logs, exist_ok=True)
         temp_dir_parsed = os.path.join(temp_dir, "parsed_output")
+        os.makedirs(temp_dir_logs, exist_ok=True)
         os.makedirs(temp_dir_parsed, exist_ok=True)
+        print(f"Temporary directory: {temp_dir}")
 
-        tritonparse.structured_logging.init(temp_dir_logs)
+        # Initialize logging
+        tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)
 
-        # Generate some triton compilation activity to create log files
+        # Generate test data and run kernels
         torch.manual_seed(0)
         size = (512, 512)  # Smaller size for faster testing
         x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
-        run_test_kernel(x)  # Run the simple kernel
+
+        # Run kernel twice to generate compilation and launch events
+        run_test_kernel(x)
+        run_test_kernel(x)
         torch.cuda.synchronize()
 
-        # Check that temp_dir_logs folder has content
+        # Verify log directory
         assert os.path.exists(
             temp_dir_logs
         ), f"Log directory {temp_dir_logs} does not exist."
         log_files = os.listdir(temp_dir_logs)
-        assert (
-            len(log_files) > 0
-        ), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
+        assert len(log_files) > 0, (
+            f"No log files found in {temp_dir_logs}. "
+            "Expected log files to be generated during Triton compilation."
+        )
         print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}")
 
+        def parse_log_line(line: str, line_num: int) -> dict | None:
+            """Parse a single log line and extract event data"""
+            try:
+                return json.loads(line.strip())
+            except json.JSONDecodeError as e:
+                print(f"  Line {line_num}: JSON decode error - {e}")
+                return None
+
+        def process_event_data(
+            event_data: dict, line_num: int, event_counts: dict
+        ) -> None:
+            """Process event data and update counts"""
+            try:
+                event_type = event_data.get("event_type")
+                if event_type is None:
+                    return
+
+                if event_type in event_counts:
+                    event_counts[event_type] += 1
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})"
+                    )
+                else:
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (not tracked)"
+                    )
+            except (KeyError, TypeError) as e:
+                print(f"  Line {line_num}: Data structure error - {e}")
+
+        def count_events_in_file(file_path: str, event_counts: dict) -> None:
+            """Count events in a single log file"""
+            print(f"Checking event types in: {os.path.basename(file_path)}")
+
+            with open(file_path, "r") as f:
+                for line_num, line in enumerate(f, 1):
+                    event_data = parse_log_line(line, line_num)
+                    if event_data:
+                        process_event_data(event_data, line_num, event_counts)
+
+        def check_event_type_counts_in_logs(log_dir: str) -> dict:
+            """Count 'launch' and unique 'compilation' events in all log files"""
+            event_counts = {"launch": 0}
+            # Track unique compilation hashes
+            compilation_hashes = set()
+
+            for log_file in os.listdir(log_dir):
+                if log_file.endswith(".ndjson"):
+                    log_file_path = os.path.join(log_dir, log_file)
+                    with open(log_file_path, "r") as f:
+                        for line_num, line in enumerate(f, 1):
+                            try:
+                                event_data = json.loads(line.strip())
+                                event_type = event_data.get("event_type")
+                                if event_type == "launch":
+                                    event_counts["launch"] += 1
+                                    print(
+                                        f"  Line {line_num}: event_type = 'launch' (count: {event_counts['launch']})"
+                                    )
+                                elif event_type == "compilation":
+                                    # Extract hash from compilation metadata
+                                    compilation_hash = (
+                                        event_data.get("payload", {})
+                                        .get("metadata", {})
+                                        .get("hash")
+                                    )
+                                    if compilation_hash:
+                                        compilation_hashes.add(compilation_hash)
+                                        print(
+                                            f"  Line {line_num}: event_type = 'compilation' (unique hash: {compilation_hash[:8]}...)"
+                                        )
+                            except (json.JSONDecodeError, KeyError, TypeError) as e:
+                                print(f"  Line {line_num}: Error processing line - {e}")
+
+            # Add the count of unique compilation hashes to the event_counts
+            event_counts["compilation"] = len(compilation_hashes)
+            print(
+                f"Event type counts: {event_counts} (unique compilation hashes: {len(compilation_hashes)})"
+            )
+            return event_counts
+
+        # Verify event counts
+        event_counts = check_event_type_counts_in_logs(temp_dir_logs)
+        assert (
+            event_counts["compilation"] == 1
+        ), f"Expected 1 unique 'compilation' hash, found {event_counts['compilation']}"
+        assert (
+            event_counts["launch"] == 2
+        ), f"Expected 2 'launch' events, found {event_counts['launch']}"
+        print(
+            "✓ Verified correct event type counts: 1 unique compilation hash, 2 launch events"
+        )
+
+        # Test parsing functionality
         tritonparse.utils.unified_parse(
             source=temp_dir_logs, out=temp_dir_parsed, overwrite=True
         )
-
-        # Clean up temporary directory
         try:
-            # Check that parsed output directory has files
+            # Verify parsing output
             parsed_files = os.listdir(temp_dir_parsed)
             assert len(parsed_files) > 0, "No files found in parsed output directory"
         finally:
+            # Clean up
             shutil.rmtree(temp_dir)
+            print("✓ Cleaned up temporary directory")
 
 
 if __name__ == "__main__":
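
The counting helpers above rely on just two fields per NDJSON line. A minimal sketch of the shapes they expect (the field names mirror what the test reads; the values are invented):

```python
import json

# Invented sample trace lines; only the field names come from the test code.
launch_line = '{"event_type": "launch"}'
compilation_line = (
    '{"event_type": "compilation",'
    ' "payload": {"metadata": {"hash": "0123abcd"}}}'
)

assert json.loads(launch_line)["event_type"] == "launch"
assert json.loads(compilation_line)["payload"]["metadata"]["hash"] == "0123abcd"
```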

tritonparse/structured_logging.py

Lines changed: 4 additions & 1 deletion
@@ -981,13 +981,16 @@ def init_basic(trace_folder: Optional[str] = None):
     maybe_enable_trace_launch()
 
 
-def init(trace_folder: Optional[str] = None):
+def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
     """
     This function is a wrapper around init_basic() that also setup the compilation listener.
 
     Args:
         trace_folder (Optional[str]): The folder to store the trace files.
     """
+    global TRITON_TRACE_LAUNCH
+    if enable_trace_launch:
+        TRITON_TRACE_LAUNCH = True
     import triton
 
     init_basic(trace_folder)
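
Taken together, the flag's control flow is roughly the following (a simplified sketch: the real `maybe_enable_trace_launch` and `init_basic` bodies are not part of this diff, so their behavior here is assumed):

```python
from typing import Optional

TRITON_TRACE_LAUNCH = False  # module-level flag

def maybe_enable_trace_launch() -> None:
    # Assumed behavior: install launch-tracing hooks when the flag is set.
    if TRITON_TRACE_LAUNCH:
        pass  # hook kernel launches for tracing

def init_basic(trace_folder: Optional[str] = None) -> None:
    # Abridged: configures trace-file logging, then consults the flag.
    maybe_enable_trace_launch()

def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False) -> None:
    global TRITON_TRACE_LAUNCH
    if enable_trace_launch:
        TRITON_TRACE_LAUNCH = True  # flipped before init_basic() reads it
    init_basic(trace_folder)
```

Passing `enable_trace_launch=True` is thus equivalent to setting the module-level flag just before `init_basic()` reads it.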
