Commit 7d5c886

limit facto to 4000 bytes rather than numel
Differential Revision: D82483935
Pull Request resolved: #14318
1 parent 75cb986 commit 7d5c886
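
The change in the title moves the cap from element count (numel) to in-memory size: a tensor can sit well under a 4000-element cap and still be far larger than 4000 bytes once the element width is taken into account. A minimal illustration in plain PyTorch (not part of this commit):

import torch

# 1500 elements is comfortably under a 4000-element (numel) cap...
t = torch.zeros(1500, dtype=torch.int64)
print(t.numel())                     # 1500
# ...but at 8 bytes per int64 element it occupies 12000 bytes,
# far past a 4000-byte cap.
print(t.numel() * t.element_size())  # 12000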

File tree

1 file changed: +44 -16 lines

backends/cadence/utils/facto_util.py

Lines changed: 44 additions & 16 deletions
@@ -22,26 +22,50 @@
 MAX_CASES = 50
 
 
+# Global cache to store generated shapes per tensor to ensure consistency
+_shape_cache: dict[str, list[int]] = {}
+
+
 def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
-    # Constraint to limit tensor size product to < 4000 with fully randomized shapes
+    # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
     import random
 
-    # Global cache to store generated shapes per tensor to ensure consistency
-    _shape_cache: dict[str, list[int]] = {}
+    def get_dtype_bytes(dtype: torch.dtype) -> int:
+        """Get the number of bytes per element for a given dtype"""
+        dtype_bytes = {
+            torch.int8: 1,
+            torch.uint8: 1,
+            torch.int16: 2,
+            torch.uint16: 2,
+            torch.int32: 4,
+            torch.float32: 4,
+            torch.int64: 8,
+            torch.float64: 8,
+            torch.bool: 1,
+            torch.float: 4,  # alias for float32
+            torch.int: 4,  # alias for int32
+            torch.long: 8,  # alias for int64
+        }
+        return dtype_bytes.get(dtype, 4)  # Default to 4 bytes if dtype not found
 
-    def generate_random_shape_with_product_limit(
-        rank: int, max_product: int = 3999, seed_base: int = 42
+    def generate_random_shape_with_byte_limit(
+        rank: int, dtype: torch.dtype, max_bytes: int = 3999, seed_base: int = 42
     ) -> list[int]:
-        """Generate a random shape with given rank ensuring product < max_product"""
+        """Generate a random shape with given rank ensuring total byte size < max_bytes"""
         random.seed(seed_base + rank)
 
+        bytes_per_element = get_dtype_bytes(dtype)
+        max_elements = max_bytes // bytes_per_element
+
         # Start with all dimensions as 1
         shape = [1] * rank
-        remaining_product = max_product - 1  # Leave room since we start with product=1
+        remaining_elements = (
+            max_elements - 1
+        )  # Leave room since we start with product=1
 
         # Randomly distribute the remaining capacity across dimensions
         for i in range(rank):
-            if remaining_product <= 1:
+            if remaining_elements <= 1:
                 break
 
             # Calculate maximum size this dimension can have without exceeding limit
@@ -51,28 +75,32 @@ def generate_random_shape_with_product_limit(
                 current_product *= shape[j]
 
             max_size_for_dim = min(
-                remaining_product // current_product, 50
+                remaining_elements // current_product, 50
             )  # Cap at 50
             if max_size_for_dim > shape[i]:
                 # Randomly choose a size between current and max
                 new_size = random.randint(shape[i], max_size_for_dim)
                 shape[i] = new_size
-                remaining_product = max_product // (current_product * new_size)
-                remaining_product = max(1, remaining_product)
+                remaining_elements = max_elements // (current_product * new_size)
+                remaining_elements = max(1, remaining_elements)
 
         # Final random shuffle of the dimensions to make it more random
         random.shuffle(shape)
         return shape
 
     def random_size_constraint(deps: object, r: int, d: int) -> int:
-        """Generate random sizes ensuring total product < 4000"""
+        """Generate random sizes ensuring total byte size < 4000 bytes"""
+        # Use conservative approach: assume worst case is 4 bytes per element (float32/int32)
+        # This ensures we never exceed 4000 bytes regardless of actual dtype
+        worst_case_dtype = torch.float32  # 4 bytes per element
+
         # Create a unique key for this tensor configuration
-        cache_key = f"{r}_{d}"
+        cache_key = f"{r}_{d}_conservative"
 
         if cache_key not in _shape_cache:
-            # Generate a new random shape for this rank
-            shape = generate_random_shape_with_product_limit(
-                r, max_product=3999, seed_base=42 + r * 10
+            # Generate a new random shape for this rank using worst-case byte estimation
+            shape = generate_random_shape_with_byte_limit(
+                r, worst_case_dtype, max_bytes=3999, seed_base=42 + r * 10 + d
             )
             _shape_cache[cache_key] = shape
 
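
Within random_size_constraint, shapes are cached in the module-level _shape_cache keyed by its r and d arguments, so repeated queries for the same tensor configuration return a consistent shape, and sizing is always done against a flat 4-byte-per-element (float32) estimate rather than the tensor's actual dtype. As a rough standalone check of the approach (a sketch only: bytes_per_element and random_shape_under_byte_limit below are hypothetical stand-ins for the commit's get_dtype_bytes and generate_random_shape_with_byte_limit, with a simplified growth loop):

import math
import random

import torch


def bytes_per_element(dtype: torch.dtype) -> int:
    # Stand-in for get_dtype_bytes: ask PyTorch for the element width directly.
    return torch.empty((), dtype=dtype).element_size()


def random_shape_under_byte_limit(
    rank: int, dtype: torch.dtype, max_bytes: int = 3999
) -> list[int]:
    # Same idea as the patch: convert the byte budget into an element budget,
    # then grow each dimension only as far as the remaining budget allows.
    max_elements = max_bytes // bytes_per_element(dtype)
    shape = [1] * rank
    for i in range(rank):
        headroom = max(1, max_elements // math.prod(shape))
        shape[i] = random.randint(1, min(headroom, 50))  # cap any single dim at 50
    random.shuffle(shape)
    return shape


shape = random_shape_under_byte_limit(rank=3, dtype=torch.float32)
total_bytes = math.prod(shape) * bytes_per_element(torch.float32)
assert total_bytes < 4000  # the element budget keeps the product under the byte cap
print(shape, total_bytes, "bytes")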
