22
22
MAX_CASES = 50
23
23
24
24
25
+ # Global cache to store generated shapes per tensor to ensure consistency
26
+ _shape_cache : dict [str , list [int ]] = {}
27
+
28
+
25
29
def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
26
- # Constraint to limit tensor size product to < 4000 with fully randomized shapes
30
+ # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
27
31
import random
28
32
29
- # Global cache to store generated shapes per tensor to ensure consistency
30
- _shape_cache : dict [str , list [int ]] = {}
33
+ def get_dtype_bytes (dtype : torch .dtype ) -> int :
34
+ """Get the number of bytes per element for a given dtype"""
35
+ dtype_bytes = {
36
+ torch .int8 : 1 ,
37
+ torch .uint8 : 1 ,
38
+ torch .int16 : 2 ,
39
+ torch .uint16 : 2 ,
40
+ torch .int32 : 4 ,
41
+ torch .float32 : 4 ,
42
+ torch .int64 : 8 ,
43
+ torch .float64 : 8 ,
44
+ torch .bool : 1 ,
45
+ torch .float : 4 , # alias for float32
46
+ torch .int : 4 , # alias for int32
47
+ torch .long : 8 , # alias for int64
48
+ }
49
+ return dtype_bytes .get (dtype , 4 ) # Default to 4 bytes if dtype not found
31
50
32
- def generate_random_shape_with_product_limit (
33
- rank : int , max_product : int = 3999 , seed_base : int = 42
51
+ def generate_random_shape_with_byte_limit (
52
+ rank : int , dtype : torch . dtype , max_bytes : int = 3999 , seed_base : int = 42
34
53
) -> list [int ]:
35
- """Generate a random shape with given rank ensuring product < max_product """
54
+ """Generate a random shape with given rank ensuring total byte size < max_bytes """
36
55
random .seed (seed_base + rank )
37
56
57
+ bytes_per_element = get_dtype_bytes (dtype )
58
+ max_elements = max_bytes // bytes_per_element
59
+
38
60
# Start with all dimensions as 1
39
61
shape = [1 ] * rank
40
- remaining_product = max_product - 1 # Leave room since we start with product=1
62
+ remaining_elements = (
63
+ max_elements - 1
64
+ ) # Leave room since we start with product=1
41
65
42
66
# Randomly distribute the remaining capacity across dimensions
43
67
for i in range (rank ):
44
- if remaining_product <= 1 :
68
+ if remaining_elements <= 1 :
45
69
break
46
70
47
71
# Calculate maximum size this dimension can have without exceeding limit
@@ -51,28 +75,32 @@ def generate_random_shape_with_product_limit(
51
75
current_product *= shape [j ]
52
76
53
77
max_size_for_dim = min (
54
- remaining_product // current_product , 50
78
+ remaining_elements // current_product , 50
55
79
) # Cap at 50
56
80
if max_size_for_dim > shape [i ]:
57
81
# Randomly choose a size between current and max
58
82
new_size = random .randint (shape [i ], max_size_for_dim )
59
83
shape [i ] = new_size
60
- remaining_product = max_product // (current_product * new_size )
61
- remaining_product = max (1 , remaining_product )
84
+ remaining_elements = max_elements // (current_product * new_size )
85
+ remaining_elements = max (1 , remaining_elements )
62
86
63
87
# Final random shuffle of the dimensions to make it more random
64
88
random .shuffle (shape )
65
89
return shape
66
90
67
91
def random_size_constraint (deps : object , r : int , d : int ) -> int :
68
- """Generate random sizes ensuring total product < 4000"""
92
+ """Generate random sizes ensuring total byte size < 4000 bytes"""
93
+ # Use conservative approach: assume worst case is 4 bytes per element (float32/int32)
94
+ # This ensures we never exceed 4000 bytes regardless of actual dtype
95
+ worst_case_dtype = torch .float32 # 4 bytes per element
96
+
69
97
# Create a unique key for this tensor configuration
70
- cache_key = f"{ r } _{ d } "
98
+ cache_key = f"{ r } _{ d } _conservative "
71
99
72
100
if cache_key not in _shape_cache :
73
- # Generate a new random shape for this rank
74
- shape = generate_random_shape_with_product_limit (
75
- r , max_product = 3999 , seed_base = 42 + r * 10
101
+ # Generate a new random shape for this rank using worst-case byte estimation
102
+ shape = generate_random_shape_with_byte_limit (
103
+ r , worst_case_dtype , max_bytes = 3999 , seed_base = 42 + r * 10 + d
76
104
)
77
105
_shape_cache [cache_key ] = shape
78
106
0 commit comments