Skip to content

Commit 409b356

Browse files
Lower complexity of get_balanced_memory by adding a set (#3776)
* Lower complexity by adding a set * Push vibe coded eval script * Clean
1 parent 1b50d93 commit 409b356

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/accelerate/utils/modeling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,8 @@ def get_balanced_memory(
10411041

10421042
# Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
10431043
leaves = get_module_leaves(module_sizes)
1044-
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves}
1044+
leaves_set = set(leaves) # Convert to set for O(1) membership testing
1045+
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
10451046
# Once removed, leaves are the final modules.
10461047
leaves = get_module_leaves(module_sizes)
10471048
mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))

0 commit comments

Comments
 (0)