Skip to content

Commit fa06e41

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Add comprehensive TensorBlobManager tests and cleanup (#162)
Summary: ## Overview This PR adds comprehensive test coverage for the `TensorBlobManager` functionality and performs minor cleanup in the test infrastructure. ## Changes ### 1. Test Infrastructure Cleanup - **Removed unnecessary torch inductor metrics reset** (`tests/test_tritonparse.py:111`) - Removed `torch._inductor.metrics.reset()` call from the `clear_all_caches()` helper function - This line was not essential for cache clearing functionality ### 2. New Test: `test_tensor_blob_manager` Added a comprehensive test case (~280 lines) to validate the TensorBlobManager functionality with four distinct test scenarios: #### Test 1: Mixed Tensor Sizes with Compression Threshold - Tests automatic compression based on size threshold (1MB) - Validates that small tensors (2KB, 400KB) are saved as `.bin` files - Validates that large tensors (20MB, 400MB) are saved as `.bin.gz` compressed files - Confirms both `.bin` and `.bin.gz` formats can be loaded correctly #### Test 2: Deduplication - Tests tensor deduplication when the same tensor is reused multiple times - Runs the same kernel 3 times with identical input - Verifies that blob count is reduced through deduplication (< 6 blobs for 3 launches) - Ensures at least 1 blob is created for the deduplicated input #### Test 3: Quota Limit Enforcement - Tests storage quota mechanism with a 60KB limit - Verifies first tensor is saved successfully - Confirms storage is disabled after quota is exceeded - Validates that no additional blobs are saved once quota is reached - Resets global variables to prevent test pollution #### Test 4: Disabled Storage - Tests that tensor blob storage can be explicitly disabled - Verifies no blob files are created when `enable_tensor_blob_storage=False` - Confirms `saved_tensors` directory remains empty ### Test Utilities The new test includes helper functions: - `collect_blob_files()`: Collects all `.bin` and `.bin.gz` files from saved_tensors directory - `count_all_blobs()`: Counts total number of blob files ## Testing Approach - Uses `unittest.skipUnless` to ensure tests only run when CUDA is available - Employs temporary directories for isolated test execution - Includes proper cleanup with `TEST_KEEP_OUTPUT` flag support for debugging - Synchronizes CUDA operations to ensure accurate testing ## Impact - ✅ Improves test coverage for tensor blob storage functionality - ✅ Validates compression, deduplication, and quota mechanisms - ✅ Ensures backward compatibility with disabled storage mode - ✅ Minor cleanup improves test infrastructure maintainability Pull Request resolved: #162 Test Plan: ```bash % python -m unittest tests.test_tritonparse -v -k test_tensor_blob_manager test_tensor_blob_manager (tests.test_tritonparse.TestTritonparseCUDA.test_tensor_blob_manager) Test TensorBlobManager functionality with context manager ... === Test 1: Mixed Tensor Sizes with Compression Threshold === Found 4 .bin files: /tmp/tmp4xaavx1u/saved_tensors/00/0074f9a0e0a6cd25693532e59b2dc8be68303dbb29c47d3b80c0572b6185fc571d5061ba1b20aff86bcf77cd3e6e023cb7ef8efda0ef5a53ff5961ef77fb71ba.bin (3625 bytes) /tmp/tmp4xaavx1u/saved_tensors/0e/0eac7954d5ade855de02ebf13bb0d98c404d3be28ee7e4ac70b7650bbe0a86c70dc6f16ec4f9ec31b0d4a0c0a52e2264925a67ded46e3007f39fb72fba28a789.bin (3625 bytes) /tmp/tmp4xaavx1u/saved_tensors/18/1840a6ec1ab8e5261aff496068189a3653095143a249c82bb73c2facde311269cf9c2ab27a90090114eca28150189d1f89f51d33702dd3f3a3183beb6ce14a77.bin (411177 bytes) /tmp/tmp4xaavx1u/saved_tensors/f8/f81de0a3530c8fbe1aabc8aca52f4d6750925ba19180924850ab5310a966a377918b3677614e38697a7ee5298f612b076d57e1cac89bc1add5f67e9dc9d00cea.bin (411177 bytes) Found 4 .bin.gz files: /tmp/tmp4xaavx1u/saved_tensors/32/32e064bc2899066269955ea172112f2a3707ca336e508ceba6717449e20d4971a24961947676618614db58f687edf0856f6e2b93dccdbeb63ea5e2bff81cc02a.bin.gz (21168 bytes) /tmp/tmp4xaavx1u/saved_tensors/bd/bdb593a4c71ad247c128bc9b378ab313cf00c776279234537f0538bca2fd91b4434f9c83e3fdcf6379808564d883d90ea32dbc9186157c77eebb926576a2d7fe.bin.gz (19424577 bytes) /tmp/tmp4xaavx1u/saved_tensors/36/368214d5a7f553271ed7697d19a1c46c3c588385cfb6832400dc2ad43fe77108679b7c0bc50c71b26252acb0c15652a8a6d4839eb3e921fa328a59203c7a6538.bin.gz (408465 bytes) /tmp/tmp4xaavx1u/saved_tensors/cd/cda0384d432651a66629e00d8ba726f194f916b83dc44aa7248b146c2dace4917004b09e8de20d668f61d48f0b00294334e144a19c64daf493f4c2a74119cd33.bin.gz (408549 bytes) ✓ Mixed sizes: 4 uncompressed (.bin), 4 compressed (.bin.gz) ✓ Compression effective: largest file compressed to 18.52 MB ✓ Successfully loaded .bin file ✓ Successfully loaded .bin.gz file ✓ Both formats (.bin and .bin.gz) verified tritonparse log file list: /tmp/tmpnn37863a/log_file_list.json INFO:tritonparse:Copying parsed logs from /tmp/tmpnn37863a to /tmp/tmpblkbqgii ================================================================================ 📁 TRITONPARSE PARSING RESULTS ================================================================================ 📂 Parsed files directory: /tmp/tmpblkbqgii 📊 Total files generated: 2 📄 Generated files: -------------------------------------------------- 1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (9.3KB) 2. 📝 log_file_list.json (181B) ================================================================================ ✅ Parsing completed successfully! ================================================================================ === Test 2: Deduplication === ✓ Deduplication working: 4 unique blob(s) for 3 launches (< 6 without dedup) WARNING:SourceMapping:No output file for kernel hash 5aeee63456fc3e7aa058576d7bece1ab3cd040db7b92ff2e8988536d850a5897, skipping. tritonparse log file list: /tmp/tmp7g7a2i9q/log_file_list.json INFO:tritonparse:Copying parsed logs from /tmp/tmp7g7a2i9q to /tmp/tmp4cr79vzr ================================================================================ 📁 TRITONPARSE PARSING RESULTS ================================================================================ 📂 Parsed files directory: /tmp/tmp4cr79vzr 📊 Total files generated: 1 📄 Generated files: -------------------------------------------------- 1. 📝 log_file_list.json (106B) ================================================================================ ✅ Parsing completed successfully! ================================================================================ === Test 3: Quota Limit === WARNING:tritonparse.structured_logging:⚠️ TENSOR BLOB STORAGE DISABLED: Storage quota would be exceeded: 0.00GB > 0.00GB limit INFO:tritonparse.structured_logging:📊 Final Tensor blob stats: 1 saved (1 total, 0 dedup), 0.00GB compressed (0.00GB uncompressed), compression ratio: 1.00x Blobs after first kernel launch: 1 Blobs after second kernel launch: 1 ✓ Quota enforced: 1 blob(s) saved before quota limit WARNING:SourceMapping:No output file for kernel hash 5aeee63456fc3e7aa058576d7bece1ab3cd040db7b92ff2e8988536d850a5897, skipping. tritonparse log file list: /tmp/tmplcl_o1dt/log_file_list.json INFO:tritonparse:Copying parsed logs from /tmp/tmplcl_o1dt to /tmp/tmpvo1mmy39 ================================================================================ 📁 TRITONPARSE PARSING RESULTS ================================================================================ 📂 Parsed files directory: /tmp/tmpvo1mmy39 📊 Total files generated: 1 📄 Generated files: -------------------------------------------------- 1. 📝 log_file_list.json (106B) ================================================================================ ✅ Parsing completed successfully! ================================================================================ ✓ Quota limit test passed (storage disabled when quota exceeded) === Test 4: Disabled Storage === ✓ Storage correctly disabled when enable_tensor_blob_storage=False WARNING:SourceMapping:No output file for kernel hash 5aeee63456fc3e7aa058576d7bece1ab3cd040db7b92ff2e8988536d850a5897, skipping. tritonparse log file list: /tmp/tmplamlg5c2/log_file_list.json INFO:tritonparse:Copying parsed logs from /tmp/tmplamlg5c2 to /tmp/tmph824g9g7 ================================================================================ 📁 TRITONPARSE PARSING RESULTS ================================================================================ 📂 Parsed files directory: /tmp/tmph824g9g7 📊 Total files generated: 1 📄 Generated files: -------------------------------------------------- 1. 📝 log_file_list.json (106B) ================================================================================ ✅ Parsing completed successfully! ================================================================================ ✓ Cleaned up all test output directories ok ---------------------------------------------------------------------- Ran 1 test in 9.121s OK ``` Reviewed By: adamomainz Differential Revision: D84215701 Pulled By: FindHao fbshipit-source-id: 2d17ac33d6c1c67dca0231f428a524f8e2413e77
1 parent 7a9de8f commit fa06e41

File tree

1 file changed

+275
-1
lines changed

1 file changed

+275
-1
lines changed

tests/test_tritonparse.py

Lines changed: 275 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def clear_all_caches(*kernels):
105105
# Reset torch compiler state
106106
torch.compiler.reset()
107107
torch._dynamo.reset()
108-
torch._inductor.metrics.reset()
109108
print("✓ Reset torch compiler, dynamo, and inductor state")
110109

111110
# Clear Triton kernel device caches for all provided kernels
@@ -1168,6 +1167,281 @@ def test_reproducer_end_to_end(self):
11681167
shutil.rmtree(temp_dir)
11691168
print("✓ Cleaned up temporary directory")
11701169

1170+
@unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
1171+
def test_tensor_blob_manager(self):
1172+
"""Test TensorBlobManager functionality with context manager"""
1173+
1174+
# Setup fresh cache for this test
1175+
test_cache_dir, prev_cache_dir = self.setup_test_with_fresh_cache()
1176+
1177+
# Define a simple kernel that accepts tensor inputs
1178+
@triton.jit
1179+
def tensor_input_kernel(
1180+
input_ptr,
1181+
output_ptr,
1182+
n_elements,
1183+
BLOCK_SIZE: tl.constexpr,
1184+
):
1185+
pid = tl.program_id(axis=0)
1186+
block_start = pid * BLOCK_SIZE
1187+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
1188+
mask = offsets < n_elements
1189+
1190+
x = tl.load(input_ptr + offsets, mask=mask)
1191+
y = x * 2.0
1192+
tl.store(output_ptr + offsets, y, mask=mask)
1193+
1194+
def run_kernel(input_tensor):
1195+
n_elements = input_tensor.numel()
1196+
output = torch.empty_like(input_tensor)
1197+
BLOCK_SIZE = 256
1198+
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
1199+
tensor_input_kernel[grid](input_tensor, output, n_elements, BLOCK_SIZE)
1200+
return output
1201+
1202+
def collect_blob_files(manager_dir_path):
1203+
"""Collect all .bin and .bin.gz files from saved_tensors directory."""
1204+
saved_tensors_dir = os.path.join(manager_dir_path, "saved_tensors")
1205+
bin_files = []
1206+
gz_files = []
1207+
1208+
if not os.path.exists(saved_tensors_dir):
1209+
return bin_files, gz_files
1210+
1211+
for subdir in os.listdir(saved_tensors_dir):
1212+
subdir_path = os.path.join(saved_tensors_dir, subdir)
1213+
if os.path.isdir(subdir_path):
1214+
for filename in os.listdir(subdir_path):
1215+
full_path = os.path.join(subdir_path, filename)
1216+
if filename.endswith(".bin.gz"):
1217+
gz_files.append(full_path)
1218+
elif filename.endswith(".bin"):
1219+
bin_files.append(full_path)
1220+
1221+
return bin_files, gz_files
1222+
1223+
def count_all_blobs(manager_dir_path):
1224+
"""Count total number of blob files (.bin and .bin.gz)."""
1225+
bin_files, gz_files = collect_blob_files(manager_dir_path)
1226+
return len(bin_files) + len(gz_files)
1227+
1228+
# Prepare test data
1229+
torch.manual_seed(0)
1230+
1231+
# === Test 1: Mixed tensor sizes with compression threshold ===
1232+
print("\n=== Test 1: Mixed Tensor Sizes with Compression Threshold ===")
1233+
temp_output_dir_1 = tempfile.mkdtemp()
1234+
1235+
with tritonparse.context_manager.TritonParseManager(
1236+
enable_trace_launch=True,
1237+
enable_tensor_blob_storage=True,
1238+
out=temp_output_dir_1,
1239+
) as manager:
1240+
# Test different tensor sizes around the 1MB compression threshold
1241+
test_cases = [
1242+
((512,), "Tiny 2KB"), # 2KB < 1MB -> .bin
1243+
((100 * 1024,), "Medium 400KB"), # 400KB < 1MB -> .bin
1244+
((5 * 1024 * 1024,), "Large 20MB"), # 20MB > 1MB -> .bin.gz
1245+
((100 * 1024 * 1024,), "Very large 400MB"), # 400MB > 1MB -> .bin.gz
1246+
]
1247+
1248+
# Create tensors and run kernels
1249+
for size, desc in test_cases:
1250+
x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
1251+
y = run_kernel(x)
1252+
y.sum()
1253+
torch.cuda.synchronize()
1254+
1255+
# Collect and verify blob files
1256+
bin_files, gz_files = collect_blob_files(manager.dir_path)
1257+
assert len(bin_files) + len(gz_files) > 0, "No blob files found"
1258+
1259+
print(f"Found {len(bin_files)} .bin files:")
1260+
for f in bin_files:
1261+
print(f" {f} ({os.path.getsize(f)} bytes)")
1262+
print(f"Found {len(gz_files)} .bin.gz files:")
1263+
for f in gz_files:
1264+
print(f" {f} ({os.path.getsize(f)} bytes)")
1265+
1266+
# Verify correct number of files (2 small uncompressed, 2 large compressed)
1267+
assert (
1268+
len(bin_files) == 4
1269+
), f"Expected 4 .bin files (2KB, 400KB), got {len(bin_files)}"
1270+
assert (
1271+
len(gz_files) == 4
1272+
), f"Expected 4 .bin.gz files (20MB, 400MB), got {len(gz_files)}"
1273+
1274+
print(
1275+
f"✓ Mixed sizes: {len(bin_files)} uncompressed (.bin), {len(gz_files)} compressed (.bin.gz)"
1276+
)
1277+
1278+
# Verify both formats can be loaded
1279+
from tritonparse.tools.load_tensor import load_tensor
1280+
1281+
if bin_files:
1282+
loaded = load_tensor(bin_files[0])
1283+
assert loaded is not None, "Failed to load .bin file"
1284+
print("✓ Successfully loaded .bin file")
1285+
1286+
if gz_files:
1287+
loaded = load_tensor(gz_files[0])
1288+
assert loaded is not None, "Failed to load .bin.gz file"
1289+
print("✓ Successfully loaded .bin.gz file")
1290+
1291+
print("✓ Both formats (.bin and .bin.gz) verified")
1292+
1293+
# === Test 2: Deduplication ===
1294+
print("\n=== Test 2: Deduplication ===")
1295+
temp_output_dir_2 = tempfile.mkdtemp()
1296+
1297+
with tritonparse.context_manager.TritonParseManager(
1298+
enable_trace_launch=True,
1299+
enable_tensor_blob_storage=True,
1300+
out=temp_output_dir_2,
1301+
) as manager:
1302+
# Use the same tensor multiple times
1303+
x = torch.randn((512,), device=self.cuda_device, dtype=torch.float32)
1304+
1305+
# Run kernel 3 times with same input
1306+
for i in range(3):
1307+
y = run_kernel(x)
1308+
y.sum()
1309+
torch.cuda.synchronize()
1310+
1311+
# Count blob files
1312+
# Note: The system may save both input and output tensors.
1313+
# - Input tensor x: reused 3 times → should deduplicate to 1 blob
1314+
# - Output tensors y: 3 separate allocations → may be 3 blobs (if different) or 1 blob (if identical)
1315+
# Expected: fewer blobs than total tensor references due to deduplication
1316+
blob_count = count_all_blobs(manager.dir_path)
1317+
# With deduplication, we should have significantly fewer blobs than 6 (3 inputs + 3 outputs)
1318+
assert (
1319+
blob_count < 6
1320+
), f"Deduplication should reduce blob count, got {blob_count} for 3 launches"
1321+
# We expect at least 1 blob (the deduplicated input)
1322+
assert blob_count >= 1, f"Should have at least 1 blob, got {blob_count}"
1323+
print(
1324+
f"✓ Deduplication working: {blob_count} unique blob(s) for 3 launches (< 6 without dedup)"
1325+
)
1326+
1327+
# === Test 3: Quota limit ===
1328+
print("\n=== Test 3: Quota Limit ===")
1329+
temp_output_dir_3 = tempfile.mkdtemp()
1330+
1331+
# Calculate quota to allow exactly one tensor to be saved
1332+
# A 10000 element float32 tensor = 10000 * 4 bytes = 40KB
1333+
# After torch.save serialization, it will be larger (includes metadata)
1334+
# Compressed size will be smaller for random data (but still substantial)
1335+
# Set quota to ~60KB to allow first tensor but not second
1336+
# Note: Random data doesn't compress as well as zeros
1337+
quota_for_one_tensor = 60 * 1024 # 60KB should fit one serialized tensor
1338+
1339+
with tritonparse.context_manager.TritonParseManager(
1340+
enable_trace_launch=True,
1341+
enable_tensor_blob_storage=True,
1342+
tensor_storage_quota=quota_for_one_tensor,
1343+
out=temp_output_dir_3,
1344+
) as manager:
1345+
# Create first tensor - should be saved successfully
1346+
large_x1 = torch.randn(
1347+
(10000,), device=self.cuda_device, dtype=torch.float32
1348+
)
1349+
y1 = run_kernel(large_x1)
1350+
y1.sum()
1351+
torch.cuda.synchronize()
1352+
1353+
# Check that first tensor was saved
1354+
blob_count_after_first = count_all_blobs(manager.dir_path)
1355+
print(f" Blobs after first kernel launch: {blob_count_after_first}")
1356+
1357+
# Create second tensor - should exceed quota and trigger storage disable
1358+
large_x2 = torch.randn(
1359+
(10000,), device=self.cuda_device, dtype=torch.float32
1360+
)
1361+
y2 = run_kernel(large_x2)
1362+
y2.sum()
1363+
torch.cuda.synchronize()
1364+
1365+
# Verify quota enforcement
1366+
blob_count_final = count_all_blobs(manager.dir_path)
1367+
print(f" Blobs after second kernel launch: {blob_count_final}")
1368+
1369+
# We expect at least 1 blob was saved (from first launch)
1370+
assert (
1371+
blob_count_after_first >= 1
1372+
), f"First tensor should be saved, got {blob_count_after_first} blobs"
1373+
1374+
# After quota exceeded, no more blobs should be added
1375+
# (blob_count_final should equal blob_count_after_first or be slightly higher
1376+
# if some outputs were saved before quota was hit)
1377+
assert (
1378+
blob_count_final <= blob_count_after_first + 1
1379+
), f"Quota should prevent saving many more blobs: first={blob_count_after_first}, final={blob_count_final}"
1380+
1381+
print(
1382+
f"✓ Quota enforced: {blob_count_after_first} blob(s) saved before quota limit"
1383+
)
1384+
1385+
# The test passes if it doesn't crash - storage should be disabled after quota exceeded
1386+
print("✓ Quota limit test passed (storage disabled when quota exceeded)")
1387+
1388+
# Reset global variables to default after Test 3 to avoid polluting Test 4
1389+
tritonparse.structured_logging.TRITONPARSE_TENSOR_STORAGE_QUOTA = (
1390+
100 * 1024 * 1024 * 1024
1391+
) # 100GB default
1392+
tritonparse.structured_logging.TRITONPARSE_SAVE_TENSOR_BLOBS = (
1393+
False # Reset to default (disabled)
1394+
)
1395+
1396+
# === Test 4: Disabled storage ===
1397+
print("\n=== Test 4: Disabled Storage ===")
1398+
temp_output_dir_4 = tempfile.mkdtemp()
1399+
1400+
# When storage is explicitly disabled, don't set quota to avoid confusion
1401+
with tritonparse.context_manager.TritonParseManager(
1402+
enable_trace_launch=True,
1403+
enable_tensor_blob_storage=False, # Explicitly disabled
1404+
out=temp_output_dir_4,
1405+
) as manager:
1406+
x = torch.randn((512,), device=self.cuda_device, dtype=torch.float32)
1407+
y = run_kernel(x)
1408+
y.sum()
1409+
torch.cuda.synchronize()
1410+
1411+
# Verify no saved_tensors directory or it's empty
1412+
total_blobs = count_all_blobs(manager.dir_path)
1413+
assert (
1414+
total_blobs == 0
1415+
), f"Expected no blobs when storage disabled, found {total_blobs}"
1416+
print("✓ Storage correctly disabled when enable_tensor_blob_storage=False")
1417+
1418+
# Clean up all test outputs
1419+
try:
1420+
if TEST_KEEP_OUTPUT:
1421+
print(
1422+
f"\n✓ Preserving output directories (TEST_KEEP_OUTPUT=1):\n"
1423+
f" Test 1: {temp_output_dir_1}\n"
1424+
f" Test 2: {temp_output_dir_2}\n"
1425+
f" Test 3: {temp_output_dir_3}\n"
1426+
f" Test 4: {temp_output_dir_4}"
1427+
)
1428+
else:
1429+
for temp_dir in [
1430+
temp_output_dir_1,
1431+
temp_output_dir_2,
1432+
temp_output_dir_3,
1433+
temp_output_dir_4,
1434+
]:
1435+
if os.path.exists(temp_dir):
1436+
shutil.rmtree(temp_dir)
1437+
print("✓ Cleaned up all test output directories")
1438+
except Exception as e:
1439+
print(f"Warning: Failed to clean up output directories: {e}")
1440+
1441+
finally:
1442+
# Cleanup test-specific cache
1443+
self.cleanup_test_cache(test_cache_dir, prev_cache_dir)
1444+
11711445

11721446
if __name__ == "__main__":
11731447
unittest.main()

0 commit comments

Comments
 (0)