6
6
```
7
7
"""
8
8
9
+ import json
9
10
import os
10
11
import shutil
11
12
import tempfile
@@ -170,12 +171,7 @@ def test_whole_workflow(self):
170
171
171
172
# Define a simple kernel directly in the test function
172
173
@triton .jit
173
- def test_kernel (
174
- x_ptr ,
175
- y_ptr ,
176
- n_elements ,
177
- BLOCK_SIZE : tl .constexpr ,
178
- ):
174
+ def test_kernel (x_ptr , y_ptr , n_elements , BLOCK_SIZE : tl .constexpr ):
179
175
pid = tl .program_id (axis = 0 )
180
176
block_start = pid * BLOCK_SIZE
181
177
offsets = block_start + tl .arange (0 , BLOCK_SIZE )
@@ -189,48 +185,147 @@ def test_kernel(
189
185
def run_test_kernel (x ):
190
186
n_elements = x .numel ()
191
187
y = torch .empty_like (x )
192
- BLOCK_SIZE = 256 # Smaller block size for simplicity
188
+ BLOCK_SIZE = 256
193
189
grid = (triton .cdiv (n_elements , BLOCK_SIZE ),)
194
190
test_kernel [grid ](x , y , n_elements , BLOCK_SIZE )
195
191
return y
196
192
193
+ # Set up test environment
197
194
temp_dir = tempfile .mkdtemp ()
198
- print (f"Temporary directory: { temp_dir } " )
199
195
temp_dir_logs = os .path .join (temp_dir , "logs" )
200
- os .makedirs (temp_dir_logs , exist_ok = True )
201
196
temp_dir_parsed = os .path .join (temp_dir , "parsed_output" )
197
+ os .makedirs (temp_dir_logs , exist_ok = True )
202
198
os .makedirs (temp_dir_parsed , exist_ok = True )
199
+ print (f"Temporary directory: { temp_dir } " )
203
200
204
- tritonparse .structured_logging .init (temp_dir_logs )
201
+ # Initialize logging
202
+ tritonparse .structured_logging .init (temp_dir_logs , enable_trace_launch = True )
205
203
206
- # Generate some triton compilation activity to create log files
204
+ # Generate test data and run kernels
207
205
torch .manual_seed (0 )
208
206
size = (512 , 512 ) # Smaller size for faster testing
209
207
x = torch .randn (size , device = self .cuda_device , dtype = torch .float32 )
210
- run_test_kernel (x ) # Run the simple kernel
208
+
209
+ # Run kernel twice to generate compilation and launch events
210
+ run_test_kernel (x )
211
+ run_test_kernel (x )
211
212
torch .cuda .synchronize ()
212
213
213
- # Check that temp_dir_logs folder has content
214
+ # Verify log directory
214
215
assert os .path .exists (
215
216
temp_dir_logs
216
217
), f"Log directory { temp_dir_logs } does not exist."
217
218
log_files = os .listdir (temp_dir_logs )
218
- assert (
219
- len (log_files ) > 0
220
- ), f"No log files found in { temp_dir_logs } . Expected log files to be generated during Triton compilation."
219
+ assert len (log_files ) > 0 , (
220
+ f"No log files found in { temp_dir_logs } . "
221
+ "Expected log files to be generated during Triton compilation."
222
+ )
221
223
print (f"Found { len (log_files )} log files in { temp_dir_logs } : { log_files } " )
222
224
225
+ def parse_log_line (line : str , line_num : int ) -> dict | None :
226
+ """Parse a single log line and extract event data"""
227
+ try :
228
+ return json .loads (line .strip ())
229
+ except json .JSONDecodeError as e :
230
+ print (f" Line { line_num } : JSON decode error - { e } " )
231
+ return None
232
+
233
+ def process_event_data (
234
+ event_data : dict , line_num : int , event_counts : dict
235
+ ) -> None :
236
+ """Process event data and update counts"""
237
+ try :
238
+ event_type = event_data .get ("event_type" )
239
+ if event_type is None :
240
+ return
241
+
242
+ if event_type in event_counts :
243
+ event_counts [event_type ] += 1
244
+ print (
245
+ f" Line { line_num } : event_type = '{ event_type } ' (count: { event_counts [event_type ]} )"
246
+ )
247
+ else :
248
+ print (
249
+ f" Line { line_num } : event_type = '{ event_type } ' (not tracked)"
250
+ )
251
+ except (KeyError , TypeError ) as e :
252
+ print (f" Line { line_num } : Data structure error - { e } " )
253
+
254
+ def count_events_in_file (file_path : str , event_counts : dict ) -> None :
255
+ """Count events in a single log file"""
256
+ print (f"Checking event types in: { os .path .basename (file_path )} " )
257
+
258
+ with open (file_path , "r" ) as f :
259
+ for line_num , line in enumerate (f , 1 ):
260
+ event_data = parse_log_line (line , line_num )
261
+ if event_data :
262
+ process_event_data (event_data , line_num , event_counts )
263
+
264
+ def check_event_type_counts_in_logs (log_dir : str ) -> dict :
265
+ """Count 'launch' and unique 'compilation' events in all log files"""
266
+ event_counts = {"launch" : 0 }
267
+ # Track unique compilation hashes
268
+ compilation_hashes = set ()
269
+
270
+ for log_file in os .listdir (log_dir ):
271
+ if log_file .endswith (".ndjson" ):
272
+ log_file_path = os .path .join (log_dir , log_file )
273
+ with open (log_file_path , "r" ) as f :
274
+ for line_num , line in enumerate (f , 1 ):
275
+ try :
276
+ event_data = json .loads (line .strip ())
277
+ event_type = event_data .get ("event_type" )
278
+ if event_type == "launch" :
279
+ event_counts ["launch" ] += 1
280
+ print (
281
+ f" Line { line_num } : event_type = 'launch' (count: { event_counts ['launch' ]} )"
282
+ )
283
+ elif event_type == "compilation" :
284
+ # Extract hash from compilation metadata
285
+ compilation_hash = (
286
+ event_data .get ("payload" , {})
287
+ .get ("metadata" , {})
288
+ .get ("hash" )
289
+ )
290
+ if compilation_hash :
291
+ compilation_hashes .add (compilation_hash )
292
+ print (
293
+ f" Line { line_num } : event_type = 'compilation' (unique hash: { compilation_hash [:8 ]} ...)"
294
+ )
295
+ except (json .JSONDecodeError , KeyError , TypeError ) as e :
296
+ print (f" Line { line_num } : Error processing line - { e } " )
297
+
298
+ # Add the count of unique compilation hashes to the event_counts
299
+ event_counts ["compilation" ] = len (compilation_hashes )
300
+ print (
301
+ f"Event type counts: { event_counts } (unique compilation hashes: { len (compilation_hashes )} )"
302
+ )
303
+ return event_counts
304
+
305
+ # Verify event counts
306
+ event_counts = check_event_type_counts_in_logs (temp_dir_logs )
307
+ assert (
308
+ event_counts ["compilation" ] == 1
309
+ ), f"Expected 1 unique 'compilation' hash, found { event_counts ['compilation' ]} "
310
+ assert (
311
+ event_counts ["launch" ] == 2
312
+ ), f"Expected 2 'launch' events, found { event_counts ['launch' ]} "
313
+ print (
314
+ "✓ Verified correct event type counts: 1 unique compilation hash, 2 launch events"
315
+ )
316
+
317
+ # Test parsing functionality
223
318
tritonparse .utils .unified_parse (
224
319
source = temp_dir_logs , out = temp_dir_parsed , overwrite = True
225
320
)
226
-
227
- # Clean up temporary directory
228
321
try :
229
- # Check that parsed output directory has files
322
+ # Verify parsing output
230
323
parsed_files = os .listdir (temp_dir_parsed )
231
324
assert len (parsed_files ) > 0 , "No files found in parsed output directory"
232
325
finally :
326
+ # Clean up
233
327
shutil .rmtree (temp_dir )
328
+ print ("✓ Cleaned up temporary directory" )
234
329
235
330
236
331
if __name__ == "__main__" :
0 commit comments