"""
defense.py

Runtime defenses against kernels that try to cheat GPU benchmark timing:
thread injection, stream injection, lazy evaluation, precision downgrade,
and monkey-patching of the CUDA timing functions.
"""
import torch
import threading
import random
from typing import Callable, Tuple, Any, Optional
# =============================================================================
# Store original CUDA functions at module load time (before any monkey-patching)
# =============================================================================
# These references are the baseline that
# defend_against_elapsed_time_monkey_patching() compares against by identity.
# They are only trustworthy if this module is imported before any untrusted
# kernel code has had a chance to replace the attributes.
_original_elapsed_time = torch.cuda.Event.elapsed_time
_original_record = torch.cuda.Event.record
_original_synchronize = torch.cuda.synchronize
def defend_against_thread_injection(
    kernel: Callable,
    *args,
    **kwargs
) -> Tuple[bool, str, Any]:
    """
    Defense against thread injection attack.

    Thread injection spawns a background thread to do the computation while
    the kernel returns an (empty) tensor immediately. This cheats timing but
    can still pass correctness checks if the thread finishes before
    verification.

    Defense: compare the set of live threads before and after the kernel
    call. Any thread that is alive afterwards but was not alive beforehand
    was started during the call.

    Args:
        kernel: The kernel function to test
        *args: Arguments to pass to the kernel
        **kwargs: Keyword arguments to pass to the kernel

    Returns:
        (passed, message, output)
        - passed: True if no new live thread was found after the call
        - message: Description of result
        - output: Kernel output
    """
    # Snapshot the Thread objects themselves rather than just their count.
    # A bare count comparison is fooled when an unrelated thread happens to
    # exit while the kernel starts a new one (the totals cancel out).
    threads_before = set(threading.enumerate())

    output = kernel(*args, **kwargs)

    spawned = [t for t in threading.enumerate() if t not in threads_before]

    if spawned:
        return False, "Kernel spawned background thread", output
    return True, "Thread injection check passed", output
def defend_against_stream_injection(
    kernel: Callable,
    *args,
    ratio_threshold: float = 1.5,
    **kwargs
) -> Tuple[bool, str, Any, Optional[float]]:
    """
    Defense against stream injection attack using hybrid approach.

    Stream injection runs the kernel on a separate CUDA stream, causing
    events recorded around the call to miss the actual computation time.

    Hybrid Defense:
    1. Run kernel with ORIGINAL timing (end event recorded right after the
       call, no extra sync before it)
    2. Run kernel with DEFENSE timing (torch.cuda.synchronize() before the
       end event is recorded)
    3. Compare: if (defense_time / original_time) > ratio_threshold, kernel
       is flagged as malicious
    4. If ratio is within threshold, use original timing (no overhead)

    The kernel is executed once for warmup and then twice per iteration
    (once per timing method, in random order) for 10 iterations.

    Args:
        kernel: The kernel function to test
        *args: Arguments to pass to the kernel
        ratio_threshold: Maximum allowed ratio of defense_time/original_time
            (default 1.5)
        **kwargs: Keyword arguments to pass to the kernel

    Returns:
        (passed, message, output, timing_ms)
        - passed: True if no stream injection detected
        - message: Description of result
        - output: Kernel output from the last timed run
        - timing_ms: The timing to use (original if legit, defense if suspicious)
    """
    # The two timing strategies do not depend on the iteration, so they are
    # defined once here instead of being rebuilt on every pass of the loop.
    def time_original():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = kernel(*args, **kwargs)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end), out

    def time_defense():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = kernel(*args, **kwargs)
        torch.cuda.synchronize()  # Force sync ALL streams before the end event
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end), out

    # Warmup
    _ = kernel(*args, **kwargs)
    torch.cuda.synchronize()

    # =========================================================================
    # RUN 10 ITERATIONS WITH RANDOM ORDER
    # =========================================================================
    num_iterations = 10
    original_times = []
    defense_times = []
    output = None

    for _ in range(num_iterations):
        # Randomly decide which timing method runs first so that ordering
        # effects do not systematically favour one of the two measurements.
        if random.choice([True, False]):
            elapsed, output = time_original()
            original_times.append(elapsed)
            elapsed, output = time_defense()
            defense_times.append(elapsed)
        else:
            elapsed, output = time_defense()
            defense_times.append(elapsed)
            elapsed, output = time_original()
            original_times.append(elapsed)

    # Use the (upper) median to reduce noise
    original_time_ms = sorted(original_times)[num_iterations // 2]
    defense_time_ms = sorted(defense_times)[num_iterations // 2]

    # =========================================================================
    # COMPARE: Large ratio indicates stream injection
    # =========================================================================
    # Avoid division by zero - if original_time is ~0, it's definitely suspicious
    if original_time_ms < 0.001:  # Less than 1µs
        ratio = float('inf')
    else:
        ratio = defense_time_ms / original_time_ms

    if ratio > ratio_threshold:
        return (
            False,
            f"Stream injection detected! Original: {original_time_ms:.3f}ms, "
            f"Defense: {defense_time_ms:.3f}ms, Ratio: {ratio:.1f}x (threshold: {ratio_threshold}x)",
            output,
            defense_time_ms  # Use defense timing for malicious kernels
        )
    return (
        True,
        f"Stream injection check passed. Ratio: {ratio:.2f}x (within {ratio_threshold}x threshold)",
        output,
        original_time_ms  # Use original timing (no overhead) for legit kernels
    )
def defend_against_lazy_evaluation(
    kernel: Callable,
    *args,
    expected_device: Optional[torch.device] = None,
    **kwargs
) -> Tuple[bool, str, Any]:
    """
    Defense against lazy evaluation attack.

    Lazy evaluation returns a tensor subclass that defers computation until
    certain methods (like __eq__) are called. This cheats timing because
    the actual computation happens during correctness check, not during timing.

    Defense: Validate that output is a standard torch.Tensor with allocated memory.

    Args:
        kernel: The kernel function to test
        *args: Arguments to pass to the kernel
        expected_device: Expected device for output (a torch.device or a
            device string). If None, it is taken from the first tensor in
            *args, falling back to 'cuda' when there is none. An expected
            device without an index (e.g. 'cuda') matches any index of that
            device type.
        **kwargs: Keyword arguments to pass to the kernel

    Returns:
        (passed, message, output)
    """
    output = kernel(*args, **kwargs)

    # Infer expected device from the first tensor argument
    if expected_device is None:
        expected_device = next(
            (arg.device for arg in args if isinstance(arg, torch.Tensor)),
            torch.device('cuda'),
        )
    # Normalise so that string devices compare correctly as well.
    expected_device = torch.device(expected_device)

    # =========================================================================
    # Validate tensor is real and materialized, not a lazy hack
    # =========================================================================
    # Check 1: Must be a tensor
    if not isinstance(output, torch.Tensor):
        return False, f"Not a tensor: {type(output)}", output

    # Check 2: Must be exactly torch.Tensor or torch.nn.Parameter.
    # Compare the class objects, not their names: a malicious subclass can
    # simply be called "Tensor" and would slip through a name comparison.
    if type(output) not in (torch.Tensor, torch.nn.Parameter):
        return False, f"Is {type(output).__name__}, not standard torch.Tensor", output

    # Check 3: Must be on correct device.
    # torch.device('cuda') != torch.device('cuda:0'), so only compare the
    # index when the expected device actually specifies one.
    actual_device = output.device
    wrong_type = actual_device.type != expected_device.type
    wrong_index = (
        expected_device.index is not None
        and actual_device.index != expected_device.index
    )
    if wrong_type or wrong_index:
        return False, f"Wrong device: {output.device} (expected {expected_device})", output

    # A zero-element tensor legitimately has no storage and a null data
    # pointer, so the allocation checks only apply to non-empty outputs.
    if output.numel() > 0:
        # Check 4: Must have allocated storage
        if output.untyped_storage().size() == 0:
            return False, "No allocated storage (likely lazy)", output

        # Check 5: Storage pointer must be valid
        if output.data_ptr() == 0:
            return False, "Storage pointer is null (likely lazy)", output

    return True, "Lazy evaluation check passed", output
def defend_against_precision_downgrade(
    kernel: Callable,
    *args,
    expected_dtype: Optional[torch.dtype] = None,
    **kwargs
) -> Tuple[bool, str, Any]:
    """
    Defense against precision downgrade attack.

    Precision downgrade uses lower precision (e.g., float16 instead of float32)
    to speed up computation while sacrificing accuracy.

    Defense: Check if output dtype matches expected dtype. Any mismatch is
    reported, and a non-tensor output fails the check instead of raising.

    Args:
        kernel: The kernel function to test
        *args: Arguments to pass to the kernel
        expected_dtype: Expected dtype for output. If None, it is taken from
            the first tensor in *args, falling back to torch.float32 when
            there is none.
        **kwargs: Keyword arguments to pass to the kernel

    Returns:
        (passed, message, output)
    """
    output = kernel(*args, **kwargs)

    # Infer expected dtype from the first tensor argument
    if expected_dtype is None:
        expected_dtype = next(
            (arg.dtype for arg in args if isinstance(arg, torch.Tensor)),
            torch.float32,
        )

    # Without this guard a kernel returning None/tuple/etc. would crash the
    # defense with AttributeError on `.dtype` instead of failing the check.
    if not isinstance(output, torch.Tensor):
        return False, f"Not a tensor: {type(output)}", output

    # Check dtype
    if output.dtype != expected_dtype:
        return False, f"Precision downgrade detected: output is {output.dtype}, expected {expected_dtype}", output

    return True, "Precision downgrade check passed", output
def defend_against_elapsed_time_monkey_patching() -> Tuple[bool, str]:
    """
    Defense against monkey-patching of the CUDA timing functions.

    A kernel can replace torch.cuda.Event.elapsed_time, torch.cuda.Event.record
    or torch.cuda.synchronize with fakes so that any kernel appears fast.

    Defense: compare the attributes currently installed on torch.cuda against
    the references captured when this module was loaded. The comparison is by
    identity, so any replacement object counts as a patch.

    Returns:
        (passed, message)
        - passed: True if all three attributes are still the captured originals
        - message: Names the patched attributes, or reports success
    """
    # (reported name, what is installed right now, what was captured at load)
    watched = (
        ("torch.cuda.Event.elapsed_time", torch.cuda.Event.elapsed_time, _original_elapsed_time),
        ("torch.cuda.Event.record", torch.cuda.Event.record, _original_record),
        ("torch.cuda.synchronize", torch.cuda.synchronize, _original_synchronize),
    )
    tampered = [label for label, current, pristine in watched if current is not pristine]

    if not tampered:
        return True, "Monkey-patching check passed"
    return False, f"Monkey-patching detected: {', '.join(tampered)}"
def run_all_defenses(
    kernel: Callable,
    *args,
    **kwargs
) -> Tuple[bool, list, Any]:
    """
    Run every defense in this module against one kernel.

    The kernel-executing defenses run first, in the order stream injection,
    thread injection, lazy evaluation, precision downgrade; each of them
    calls the kernel itself. The monkey-patching check runs last and does
    not execute the kernel.

    Args:
        kernel: The kernel function to test
        *args: Arguments to pass to the kernel
        **kwargs: Keyword arguments to pass to the kernel

    Returns:
        (all_passed, results, output)
        - all_passed: True if every defense reported passed
        - results: List of (defense_name, passed, message) tuples, in run order
        - output: The kernel output returned by the last kernel-executing defense
    """
    kernel_checks = (
        ("stream_injection", defend_against_stream_injection),
        ("thread_injection", defend_against_thread_injection),
        ("lazy_evaluation", defend_against_lazy_evaluation),
        ("precision_downgrade", defend_against_precision_downgrade),
    )

    results = []
    output = None
    for label, check in kernel_checks:
        # The stream check returns an extra timing element; it is not needed here.
        verdict, note, output, *_ = check(kernel, *args, **kwargs)
        results.append((label, verdict, note))

    # Does not need kernel execution.
    verdict, note = defend_against_elapsed_time_monkey_patching()
    results.append(("elapsed_time_monkey_patching", verdict, note))

    all_passed = all(entry[1] for entry in results)
    return all_passed, results, output
# =============================================================================
# Testing
# =============================================================================
def legit_kernel(A, B):
    """Honest reference kernel: return torch.matmul(A, B) computed in place, with no tricks."""
    product = torch.matmul(A, B)
    return product
def test_thread_injection():
    """Print the thread-injection verdict for the honest kernel and for a thread-spawning one."""

    def malicious_kernel(A, B):
        # Returns the buffer immediately; a worker thread fills it in later.
        out = torch.empty(A.size(0), B.size(1), device=A.device)

        def compute():
            result = torch.matmul(A, B)
            out.copy_(result)

        worker = threading.Thread(target=compute)
        worker.start()
        return out

    # Setup
    A = torch.randn(10000, 10000).cuda()
    B = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()

    rule = "=" * 50
    scenarios = (
        ("Testing legitimate kernel:", legit_kernel),
        ("Testing malicious kernel (thread injection):", malicious_kernel),
    )
    for position, (heading, candidate) in enumerate(scenarios):
        if position:
            print()  # blank line between the two reports
        print(rule)
        print(heading)
        print(rule)
        passed, msg, _ = defend_against_thread_injection(candidate, A, B)
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" [{status}] {msg}")
def test_stream_injection():
    """Print the stream-injection verdict for the honest kernel and for a side-stream one."""

    def malicious_kernel(A, B):
        # Launches the matmul on a freshly created stream instead of the current one.
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            return torch.matmul(A, B)

    # Setup
    A = torch.randn(10000, 10000).cuda()
    B = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()

    rule = "=" * 50
    scenarios = (
        ("Testing legitimate kernel:", legit_kernel),
        ("Testing malicious kernel (stream injection):", malicious_kernel),
    )
    for position, (heading, candidate) in enumerate(scenarios):
        if position:
            print()  # blank line between the two reports
        print(rule)
        print(heading)
        print(rule)
        passed, msg, _, _ = defend_against_stream_injection(candidate, A, B)
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" [{status}] {msg}")
def test_lazy_evaluation():
    """Print the lazy-evaluation verdict for the honest kernel and for a lazy tensor subclass."""

    class LazyMatmul(torch.Tensor):
        # Tensor subclass that keeps the operands and only multiplies them
        # when it is compared with ==.
        @staticmethod
        def __new__(cls, A, B):
            placeholder = torch.empty(A.size(0), B.size(1), device=A.device)
            obj = torch.Tensor._make_subclass(cls, placeholder)
            obj.A = A
            obj.B = B
            return obj

        def __eq__(self, other):
            return torch.matmul(self.A, self.B) == other

    def malicious_kernel(A, B):
        return LazyMatmul(A, B)

    # Setup
    A = torch.randn(10000, 10000).cuda()
    B = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()

    rule = "=" * 50
    scenarios = (
        ("Testing legitimate kernel:", legit_kernel),
        ("Testing malicious kernel (lazy evaluation):", malicious_kernel),
    )
    for position, (heading, candidate) in enumerate(scenarios):
        if position:
            print()  # blank line between the two reports
        print(rule)
        print(heading)
        print(rule)
        passed, msg, _ = defend_against_lazy_evaluation(candidate, A, B)
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" [{status}] {msg}")
def test_precision_downgrade():
    """Print the precision-downgrade verdict for the honest kernel and for a float16 one."""

    def malicious_kernel(A, B):
        # Casts both operands to float16 before multiplying.
        return torch.matmul(A.half(), B.half())

    # Setup (float32 inputs)
    A = torch.randn(10000, 10000).cuda()
    B = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()

    rule = "=" * 50
    scenarios = (
        ("Testing legitimate kernel:", legit_kernel),
        ("Testing malicious kernel (precision downgrade):", malicious_kernel),
    )
    for position, (heading, candidate) in enumerate(scenarios):
        if position:
            print()  # blank line between the two reports
        print(rule)
        print(heading)
        print(rule)
        passed, msg, _ = defend_against_precision_downgrade(candidate, A, B)
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" [{status}] {msg}")
def test_elapsed_time_monkey_patching():
    """
    Test the monkey-patching defense.

    Runs the check three times and prints each verdict: before any patch,
    while torch.cuda.Event.elapsed_time is replaced by a fake, and after the
    original has been put back. The patch is undone in a ``finally`` block so
    a failure in the middle cannot leave the fake installed for later code.
    """
    rule = "=" * 50

    def report(heading):
        # Print one banner plus the verdict of the defense at this moment.
        print(rule)
        print(heading)
        print(rule)
        passed, msg = defend_against_elapsed_time_monkey_patching()
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" [{status}] {msg}")

    report("Testing before monkey-patching:")
    print()

    def _fake_elapsed_time(self, end_event):
        return 0.001  # Fake fast!

    # Apply monkey-patch
    torch.cuda.Event.elapsed_time = _fake_elapsed_time
    try:
        report("Testing after monkey-patching elapsed_time:")
    finally:
        # Restore original even if the check or the printing raised.
        torch.cuda.Event.elapsed_time = _original_elapsed_time

    print()
    report("Testing after restoring original:")
def test_run_all_defenses():
    """Run run_all_defenses on the honest kernel and on three attack kernels, printing each report."""

    def stream_attack(A, B):
        # Matmul launched on a freshly created side stream.
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            return torch.matmul(A, B)

    def thread_attack(A, B):
        # Returns the buffer immediately; a worker thread fills it in later.
        out = torch.empty(A.size(0), B.size(1), device=A.device)

        def compute():
            result = torch.matmul(A, B)
            out.copy_(result)

        worker = threading.Thread(target=compute)
        worker.start()
        return out

    def precision_attack(A, B):
        # Matmul performed on float16 copies of the inputs.
        return torch.matmul(A.half(), B.half())

    # Setup
    A = torch.randn(10000, 10000).cuda()
    B = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()

    rule = "=" * 60
    scenarios = (
        ("Testing legitimate kernel with ALL defenses:", legit_kernel),
        ("Testing stream injection attack with ALL defenses:", stream_attack),
        ("Testing thread injection attack with ALL defenses:", thread_attack),
        ("Testing precision downgrade attack with ALL defenses:", precision_attack),
    )
    for position, (heading, candidate) in enumerate(scenarios):
        if position:
            print("\n")  # separator between consecutive reports
        print(rule)
        print(heading)
        print(rule)
        all_passed, results, _ = run_all_defenses(candidate, A, B)
        for name, passed, msg in results:
            status = "✓ PASS" if passed else "✗ FAIL"
            print(f" [{status}] {name}: {msg}")
        print(f"\n Overall: {'✓ ALL PASSED' if all_passed else '✗ SOME FAILED'}")
if __name__ == "__main__":
    # Script entry point. Only the combined demo runs by default; the
    # per-defense demos below are left disabled and can be uncommented to
    # run one of them on its own.
    # test_thread_injection()
    # test_stream_injection()
    # test_lazy_evaluation()
    # test_precision_downgrade()
    # test_elapsed_time_monkey_patching()
    test_run_all_defenses()