2222MAX_CASES = 50
2323
2424
25+ # Global cache to store generated shapes per tensor to ensure consistency
26+ _shape_cache : dict [str , list [int ]] = {}
27+
28+
2529def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
26- # Constraint to limit tensor size product to < 4000 with fully randomized shapes
30+ # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
2731 import random
2832
29- # Global cache to store generated shapes per tensor to ensure consistency
30- _shape_cache : dict [str , list [int ]] = {}
33+ def get_dtype_bytes (dtype : torch .dtype ) -> int :
34+ """Get the number of bytes per element for a given dtype"""
35+ dtype_bytes = {
36+ torch .int8 : 1 ,
37+ torch .uint8 : 1 ,
38+ torch .int16 : 2 ,
39+ torch .uint16 : 2 ,
40+ torch .int32 : 4 ,
41+ torch .float32 : 4 ,
42+ torch .int64 : 8 ,
43+ torch .float64 : 8 ,
44+ torch .bool : 1 ,
45+ torch .float : 4 , # alias for float32
46+ torch .int : 4 , # alias for int32
47+ torch .long : 8 , # alias for int64
48+ }
49+ return dtype_bytes .get (dtype , 4 ) # Default to 4 bytes if dtype not found
3150
32- def generate_random_shape_with_product_limit (
33- rank : int , max_product : int = 3999 , seed_base : int = 42
51+ def generate_random_shape_with_byte_limit (
52+ rank : int , dtype : torch . dtype , max_bytes : int = 3999 , seed_base : int = 42
3453 ) -> list [int ]:
35- """Generate a random shape with given rank ensuring product < max_product """
54+ """Generate a random shape with given rank ensuring total byte size < max_bytes """
3655 random .seed (seed_base + rank )
3756
57+ bytes_per_element = get_dtype_bytes (dtype )
58+ max_elements = max_bytes // bytes_per_element
59+
3860 # Start with all dimensions as 1
3961 shape = [1 ] * rank
40- remaining_product = max_product - 1 # Leave room since we start with product=1
62+ remaining_elements = (
63+ max_elements - 1
64+ ) # Leave room since we start with product=1
4165
4266 # Randomly distribute the remaining capacity across dimensions
4367 for i in range (rank ):
44- if remaining_product <= 1 :
68+ if remaining_elements <= 1 :
4569 break
4670
4771 # Calculate maximum size this dimension can have without exceeding limit
@@ -51,28 +75,32 @@ def generate_random_shape_with_product_limit(
5175 current_product *= shape [j ]
5276
5377 max_size_for_dim = min (
54- remaining_product // current_product , 50
78+ remaining_elements // current_product , 50
5579 ) # Cap at 50
5680 if max_size_for_dim > shape [i ]:
5781 # Randomly choose a size between current and max
5882 new_size = random .randint (shape [i ], max_size_for_dim )
5983 shape [i ] = new_size
60- remaining_product = max_product // (current_product * new_size )
61- remaining_product = max (1 , remaining_product )
84+ remaining_elements = max_elements // (current_product * new_size )
85+ remaining_elements = max (1 , remaining_elements )
6286
6387 # Final random shuffle of the dimensions to make it more random
6488 random .shuffle (shape )
6589 return shape
6690
6791 def random_size_constraint (deps : object , r : int , d : int ) -> int :
68- """Generate random sizes ensuring total product < 4000"""
92+ """Generate random sizes ensuring total byte size < 4000 bytes"""
93+ # Use conservative approach: assume worst case is 4 bytes per element (float32/int32)
94+ # This ensures we never exceed 4000 bytes regardless of actual dtype
95+ worst_case_dtype = torch .float32 # 4 bytes per element
96+
6997 # Create a unique key for this tensor configuration
70- cache_key = f"{ r } _{ d } "
98+ cache_key = f"{ r } _{ d } _conservative "
7199
72100 if cache_key not in _shape_cache :
73- # Generate a new random shape for this rank
74- shape = generate_random_shape_with_product_limit (
75- r , max_product = 3999 , seed_base = 42 + r * 10
101+ # Generate a new random shape for this rank using worst-case byte estimation
102+ shape = generate_random_shape_with_byte_limit (
103+ r , worst_case_dtype , max_bytes = 3999 , seed_base = 42 + r * 10 + d
76104 )
77105 _shape_cache [cache_key ] = shape
78106
0 commit comments