@@ -323,13 +323,144 @@ where
323323 format ! ( "{}" , petgraph:: dot:: Dot :: new( g) )
324324}
325325
326- pub trait LazyBuffer : Debug + Clone { }
326+ pub trait LazyBuffer : Debug + Clone + PartialEq { }
327327
328328pub trait LazyAllocator < B : LazyBuffer > {
329+ fn initialize (
330+ & mut self ,
331+ graph : & mut OpGraph ,
332+ edges : & [ EdgeIndex ] ,
333+ last_node : NodeIndex ,
334+ ) -> Result < ( ) > ;
329335 fn insert ( & mut self , id : BufferId , buffer : B ) -> Result < ( ) > ;
336+ fn get ( & self , id : BufferId ) -> Option < & B > ;
330337 fn get_or_allocate ( & mut self , id : BufferId , shape : & Shape , dtype : DType ) -> Result < B > ;
331338}
332339
340+ pub fn determine_tensor_source < ' a > ( graph : & ' a OpGraph , edge : & ' a Edge < OpEdge > ) -> & ' a Edge < OpEdge > {
341+ let mut source = edge;
342+ loop {
343+ let next_edge = source. next_edge ( petgraph:: Incoming ) ;
344+ if next_edge == EdgeIndex :: end ( ) {
345+ break ;
346+ }
347+ source = & graph. raw_edges ( ) [ next_edge. index ( ) ] ;
348+ }
349+ source
350+ }
351+
352+ pub fn calculate_usage_records (
353+ graph : & OpGraph ,
354+ edges : & [ EdgeIndex ] ,
355+ ) -> HashMap < BufferId , ( Option < BufferId > , Option < usize > , usize , Shape , DType ) > {
356+ let mut records = HashMap :: with_capacity ( edges. len ( ) ) ;
357+ let topo_len = edges. len ( ) - 1 ;
358+ for ( i, edge_idx) in edges. iter ( ) . rev ( ) . enumerate ( ) {
359+ let edge = & graph. raw_edges ( ) [ edge_idx. index ( ) ] ;
360+ let buffer_id = edge. weight . buffer_id ( ) ;
361+ let node_idx = edge. source ( ) ;
362+
363+ let t = & graph[ node_idx] ;
364+ if t. resolved ( ) {
365+ continue ;
366+ }
367+ let incoming = graph. edges_directed ( node_idx, petgraph:: Incoming ) ;
368+ for in_idx in incoming {
369+ let in_edge: & Edge < OpEdge > = & graph. raw_edges ( ) [ in_idx. id ( ) . index ( ) ] ;
370+ let source_idx = in_edge. source ( ) ;
371+
372+ let source = & graph[ source_idx] ;
373+ if source. resolved ( ) {
374+ continue ;
375+ }
376+ let true_source = determine_tensor_source ( graph, in_edge) ;
377+ records
378+ . entry ( true_source. weight . buffer_id ( ) )
379+ . or_insert_with ( || {
380+ (
381+ None ,
382+ None ,
383+ topo_len - i,
384+ true_source. weight . shape ( ) . clone ( ) ,
385+ true_source. weight . dtype ( ) ,
386+ )
387+ } ) ;
388+ }
389+
390+ if let Some ( record) = records. get_mut ( & edge. weight . buffer_id ( ) ) {
391+ record. 0 = Some ( edge. weight . buffer_id ( ) ) ;
392+ record. 1 = Some ( topo_len - i) ;
393+ }
394+ }
395+ //filter records with no producer
396+ records. retain ( |_, v| v. 1 . is_some ( ) ) ;
397+ records
398+ }
399+ pub fn greedy_by_size < A : LazyAllocator < B > , B : LazyBuffer > (
400+ graph : & OpGraph ,
401+ edges : & [ EdgeIndex ] ,
402+ allocator : & mut A ,
403+ ) -> Result < ( ) > {
404+ let record_map = calculate_usage_records ( graph, edges) ;
405+ let mut shared_objects: Vec < B > = Vec :: with_capacity ( record_map. len ( ) ) ;
406+
407+ for ( buffer_id, ( record_buffer_id, producer, last_consumer, shape, dtype) ) in record_map. iter ( )
408+ {
409+ let record_producer = producer. unwrap ( ) ;
410+ let mut best_obj = None ;
411+ for obj in shared_objects. iter ( ) {
412+ let mut suitable = true ;
413+ for (
414+ inner_buffer_id,
415+ ( _, inner_producer, inner_last_consumer, inner_shape, inner_dtype) ,
416+ ) in record_map. iter ( )
417+ {
418+ let max_first = std:: cmp:: max ( record_producer, inner_producer. unwrap ( ) ) ;
419+ let min_last = * std:: cmp:: min ( last_consumer, inner_last_consumer) ;
420+ if max_first <= min_last && allocator. get ( * inner_buffer_id) == Some ( obj) {
421+ suitable = false ;
422+ break ;
423+ }
424+ }
425+ if suitable {
426+ best_obj = Some ( obj) ;
427+ }
428+ }
429+ if let Some ( obj) = best_obj {
430+ allocator. insert ( * buffer_id, ( * obj) . clone ( ) ) ?;
431+ } else {
432+ //let rounded_size = (record.size - 1).next_power_of_two();
433+ let buffer = allocator. get_or_allocate ( * buffer_id, shape, * dtype) ?;
434+ shared_objects. push ( buffer. clone ( ) ) ;
435+ }
436+ }
437+
438+ //Loop through and add inplace assignments
439+ for edge_idx in edges. iter ( ) {
440+ let edge = & graph. raw_edges ( ) [ edge_idx. index ( ) ] ;
441+ let node_idx = edge. source ( ) ;
442+ let t = & graph[ node_idx] ;
443+ if t. resolved ( ) {
444+ continue ;
445+ }
446+ let incoming = graph. edges_directed ( node_idx, petgraph:: Incoming ) ;
447+ for in_idx in incoming {
448+ let in_edge: & Edge < OpEdge > = & graph. raw_edges ( ) [ in_idx. id ( ) . index ( ) ] ;
449+
450+ let true_source = determine_tensor_source ( graph, in_edge) ;
451+ if true_source. weight . buffer_id ( ) != in_edge. weight . buffer_id ( ) {
452+ if let Some ( buf) = allocator. get ( true_source. weight . buffer_id ( ) ) {
453+ allocator. insert ( in_edge. weight . buffer_id ( ) , buf. clone ( ) ) ?;
454+ }
455+ }
456+ }
457+ }
458+
459+ //We use `immediate` = false here in create_buffer
460+ //and submit the queue after all allocations are done.
461+ Ok ( ( ) )
462+ }
463+
333464pub trait Executor {
334465 type BufferType : LazyBuffer ;
335466 type AllocatorType : LazyAllocator < Self :: BufferType > ;
@@ -349,6 +480,8 @@ pub trait Executor {
349480 allocator : & mut Self :: AllocatorType ,
350481 node : NodeIndex ,
351482 ) -> Result < ( ) > ;
483+
484+ fn allocator ( & self ) -> Self :: AllocatorType ;
352485}
353486
354487impl LazyStorage {
@@ -387,17 +520,29 @@ impl OpEdge {
387520 }
388521 }
389522
523+ pub fn id ( & self ) -> EdgeId {
524+ self . id
525+ }
526+
390527 pub fn layout ( & self ) -> & Layout {
391528 & self . layout
392529 }
393530
531+ pub fn shape ( & self ) -> & Shape {
532+ & self . layout . shape ( )
533+ }
534+
394535 pub fn dtype ( & self ) -> DType {
395536 self . dtype
396537 }
397538
398539 pub fn buffer_id ( & self ) -> BufferId {
399540 self . buffer_id
400541 }
542+
543+ pub fn bytes ( & self ) -> usize {
544+ self . layout . shape ( ) . elem_count ( ) * self . dtype . size_in_bytes ( )
545+ }
401546}
402547
403548#[ derive( Debug , Clone , PartialEq ) ]
@@ -421,6 +566,10 @@ impl OpNode {
421566 pub fn op ( & self ) -> & Op {
422567 & self . op
423568 }
569+
570+ pub fn resolved ( & self ) -> bool {
571+ matches ! ( self . op, Op :: Const ( _) )
572+ }
424573}
425574
426575#[ derive( Debug , Clone , PartialEq ) ]
@@ -769,15 +918,15 @@ impl BackendStorage for LazyStorage {
769918 let idx = next. add_operation ( Op :: WhereCond ) ;
770919
771920 let current_op = next. get_current_node ( ) ?;
772- let lhs_edge = OpEdge :: new ( l. clone ( ) , self . dtype ( ) ) ;
773- next. operations . add_edge ( current_op, idx, lhs_edge ) ;
921+ let src = OpEdge :: new ( l. clone ( ) , self . dtype ( ) ) ;
922+ next. operations . add_edge ( current_op, idx, src ) ;
774923
775924 let t_op = t. get_current_node ( ) ?;
776925 let t_edge = OpEdge :: new ( t_l. clone ( ) , t. dtype ( ) ) ;
777926 next. merge ( t, t_op, idx, t_edge) ?;
778927
779928 let f_op = f. get_current_node ( ) ?;
780- let f_edge = OpEdge :: new ( f_l. clone ( ) , t . dtype ( ) ) ;
929+ let f_edge = OpEdge :: new ( f_l. clone ( ) , f . dtype ( ) ) ;
781930 next. merge ( f, f_op, idx, f_edge) ?;
782931
783932 next. current_node = Some ( idx) ;
@@ -1052,7 +1201,7 @@ impl BackendStorage for LazyStorage {
10521201 let idx = next. add_operation ( op) ;
10531202
10541203 let current_op = next. get_current_node ( ) ?;
1055- let edge = OpEdge :: new ( src_l. clone ( ) , self . dtype ( ) ) ;
1204+ let edge = OpEdge :: new ( src_l. clone ( ) , ids . dtype ( ) ) ;
10561205 next. operations . add_edge ( current_op, idx, edge) ;
10571206
10581207 let ids_op = ids. get_current_node ( ) ?;
0 commit comments