@@ -207,19 +207,22 @@ async def _collect_input_sizes(
207207 * (storage_api .get_infos .delay (k ) for k in fetch_keys )
208208 )
209209
210+ # Compute the memory quota size. When the data is located in shared memory,
211+ # the cost should be the difference between the deserialized memory cost and
212+ # the serialized cost; otherwise we should take the deserialized memory cost.
210213 for key , meta , infos in zip (fetch_keys , fetch_metas , data_infos ):
211214 level = functools .reduce (operator .or_ , (info .level for info in infos ))
212215 if level & StorageLevel .MEMORY :
213216 mem_cost = max (0 , meta ["memory_size" ] - meta ["store_size" ])
214217 else :
215218 mem_cost = meta ["memory_size" ]
216- sizes [key ] = (mem_cost , mem_cost )
219+ sizes [key ] = (meta [ "store_size" ] , mem_cost )
217220
218221 return sizes
219222
220223 @classmethod
221224 def _estimate_sizes (cls , subtask : Subtask , input_sizes : Dict ):
222- size_context = { k : ( s , 0 ) for k , ( s , _c ) in input_sizes .items ()}
225+ size_context = dict ( input_sizes .items ())
223226 graph = subtask .chunk_graph
224227
225228 key_to_ops = defaultdict (set )
@@ -243,7 +246,7 @@ def _estimate_sizes(cls, subtask: Subtask, input_sizes: Dict):
243246
244247 visited_op_keys = set ()
245248 total_memory_cost = 0
246- max_memory_cost = 0
249+ max_memory_cost = sum ( calc_size for _ , calc_size in size_context . values ())
247250 while key_stack :
248251 key = key_stack .pop ()
249252 op = key_to_ops [key ][0 ]
@@ -255,24 +258,31 @@ def _estimate_sizes(cls, subtask: Subtask, input_sizes: Dict):
255258 total_memory_cost += calc_cost
256259 max_memory_cost = max (total_memory_cost , max_memory_cost )
257260
258- result_cost = sum (size_context [out .key ][0 ] for out in op .outputs )
259- total_memory_cost += result_cost - calc_cost
261+ if not isinstance (op , Fetch ):
262+ # when the calculation result is stored, the memory cost of the
263+ # calculation can be replaced with the result's memory cost
264+ result_cost = sum (size_context [out .key ][0 ] for out in op .outputs )
265+ total_memory_cost += result_cost - calc_cost
260266
261- visited_op_keys .add (op . key )
267+ visited_op_keys .add (key )
262268
263269 for succ_op_key in op_key_graph .iter_successors (key ):
264270 pred_ref_count [succ_op_key ] -= 1
265271 if pred_ref_count [succ_op_key ] == 0 :
266272 key_stack .append (succ_op_key )
273+
267274 for pred_op_key in op_key_graph .iter_predecessors (key ):
268275 succ_ref_count [pred_op_key ] -= 1
269276 if succ_ref_count [pred_op_key ] == 0 :
277+ pred_op = key_to_ops [pred_op_key ][0 ]
278+ # when clearing fetched data, subtract the memory size; otherwise subtract the store size
279+ account_idx = 1 if isinstance (pred_op , Fetch ) else 0
270280 pop_result_cost = sum (
271- size_context .pop (out .key , (0 , 0 ))[0 ]
281+ size_context .pop (out .key , (0 , 0 ))[account_idx ]
272282 for out in key_to_ops [pred_op_key ][0 ].outputs
273283 )
274284 total_memory_cost -= pop_result_cost
275- return sum (t [1 ] for t in size_context .values ()), max_memory_cost
285+ return sum (t [0 ] for t in size_context .values ()), max_memory_cost
276286
277287 @classmethod
278288 def _check_cancelling (cls , subtask_info : SubtaskExecutionInfo ):
0 commit comments