|
# Experimental eager cache eviction.
# Expect things to break if any dynamic prompts are used.
import functools
import logging

from comfy_execution import graph
from comfy_execution.caching import HierarchicalCache, CacheKeySetInputSignature, CacheKeySetID

import execution
| 7 | + |
def is_link(inp):
    """Return True if *inp* has the shape of a node link: a two-element list."""
    if not isinstance(inp, list):
        return False
    return len(inp) == 2
def link_count(dynprompt, node_id):
    """Return how many of *node_id*'s inputs are links to other nodes.

    Counts every link occurrence, so two inputs fed by the same producer
    count twice. Uses a generator expression instead of materializing a
    throwaway list inside sum().
    """
    return sum(is_link(x) for x in dynprompt.get_node(node_id)['inputs'].values())
| 12 | + |
class MinCache(HierarchicalCache):
    """A HierarchicalCache that eagerly evicts a node's cached output as soon
    as every node consuming it has executed.

    NOTE(review): assumes each node is set() at most once per prompt run; a
    second set() for the same dependent would raise ValueError from
    list.remove -- confirm against the executor's contract.
    """

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        """Run base-class bookkeeping, then build the reverse-dependency map
        (producer node id -> list of consumer node ids)."""
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.dependents = {}
        for node_id in node_ids:
            inputs = dynprompt.get_node(node_id)['inputs']
            for inp in inputs.values():
                # One entry per link occurrence: a node linking the same
                # producer twice appears twice, matching the two removals
                # performed in set() below.
                if is_link(inp):
                    self.dependents.setdefault(inp[0], []).append(node_id)

    def set(self, node_id, value):
        """Cache *value* for *node_id*, then evict any producer whose last
        remaining consumer was this node."""
        super().set(node_id, value)
        inputs = self.dynprompt.get_node(node_id)['inputs']
        for inp in inputs.values():
            if not is_link(inp):
                continue
            input_id = inp[0]
            self.dependents[input_id].remove(node_id)
            if not self.dependents[input_id]:
                # Last consumer is done: drop the producer's cached output
                # now rather than holding it until the prompt finishes.
                cache_key = self.cache_key_set.get_data_key(input_id)
                del self.cache[cache_key]
| 35 | + |
def init_cache(self):
    """Replacement for CacheSet.init_classic_cache that swaps in the
    eagerly-evicting MinCache for node outputs."""
    # ui and objects keep the stock hierarchical caches; only the outputs
    # cache changes behavior.
    self.objects = HierarchicalCache(CacheKeySetID)
    self.ui = HierarchicalCache(CacheKeySetInputSignature)
    self.outputs = MinCache(CacheKeySetInputSignature)

# Monkey-patch ComfyUI's classic cache constructor with the variant above.
execution.CacheSet.init_classic_cache = init_cache
| 41 | + |
class MincacheExecutionList(graph.ExecutionList):
    """ExecutionList that orders ready nodes so cached intermediate results
    can be evicted as early as possible (pairs with MinCache).

    Fixes: leftover debug print() calls wrote to stdout on every staging
    step; diagnostics now go through logging.debug.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # depth[node_id]: 1 + max depth over the node's dependents
        # (0 when the node has no recorded dependents).
        self.depth = {}

    def stage_node_execution(self):
        """Pick the next node to stage, preferring nodes that let cached
        inputs be freed soonest."""
        assert self.staged_node_id is None
        if self.is_empty():
            return None, None, None
        available = self.get_ready_nodes()
        if len(available) == 0:
            # Nothing ready right now -- defer to the base implementation.
            return super().stage_node_execution()
        # Order: most incoming links first (closest to freeing cached
        # inputs), then greatest depth, then fewest blockers; node id breaks
        # the remaining ties deterministically.
        available.sort(key=lambda x: (-link_count(self.dynprompt, x),
                                      -self.depth.get(x, 0),
                                      len(self.blocking[x]), x))
        logging.getLogger(__name__).debug(
            "mincache candidates: %s",
            [self.dynprompt.get_node(x)['class_type'] for x in available])
        self.staged_node_id = available[0]
        return self.staged_node_id, None, None

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        """Record the link, then propagate depth upstream:
        depth(from) = max(depth(from), depth(to) + 1)."""
        super().add_strong_link(from_node_id, from_socket, to_node_id)
        self.depth[from_node_id] = max(self.depth.get(to_node_id, 0) + 1,
                                       self.depth.get(from_node_id, 0))

# Monkey-patch the executor to use this ordering for all prompt runs.
execution.ExecutionList = MincacheExecutionList
| 66 | + |
'''
Prioritization order:
- A computation that allows clearing a cached result
- A computation that progresses towards clearing a cached item
- A computation that is of the greatest depth among cached items
  - depth is 1 + max(0, *dependent_depths)

Sort nodes by the tuple (-num_cached_dependencies, uncached_dependencies (always 0?), -depth).
'''
0 commit comments