@@ -10,7 +10,7 @@ use crate::{
1010 dim:: { DDyn , DynDimension } ,
1111 dtypes:: { AsDynType , DType , ElementType , ScalarType } ,
1212 op_descriptor,
13- operator:: { DataParam , OperatorDescriptor , OperatorParameter } ,
13+ operator:: { DataParam , DataParamWithExternalId , OperatorDescriptor , OperatorParameter } ,
1414 operators:: tensor:: TensorOperator ,
1515 storage:: {
1616 gpu:: { InplaceHandle , InplaceResult , WriteHandle } ,
@@ -488,12 +488,18 @@ impl<T: Identify> OrderedSet<T> {
488488}
489489
490490//TODO: Rc for cheap clone?
491- #[ derive( Identify , Clone ) ]
491+ #[ derive( Clone ) ]
492492pub struct JitTensorOperator < D : DynDimension > {
493493 metadata : Option < TensorMetaData < D > > ,
494494 dtype : DType ,
495495 operators : OrderedSet < TensorOperator < D , DType > > ,
496496 instructions : Vec < Instruction > ,
497+ id : Id ,
498+ }
499+ impl < D : DynDimension > Identify for JitTensorOperator < D > {
500+ fn id ( & self ) -> Id {
501+ self . id
502+ }
497503}
498504
499505fn merge_instructions < D : DynDimension > (
@@ -584,6 +590,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
584590 }
585591
586592 Ok ( {
593+ let id = Id :: combine ( & [ op. id ( ) , inner. id ( ) ] ) ;
587594 let dtype = op. dtype ( inner. dtype ) ?;
588595 let op = Instruction :: Unary ( dtype, op, InstructionOffset ( 1 ) ) ;
589596 let mut ops = inner. instructions ;
@@ -593,6 +600,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
593600 dtype,
594601 operators : inner. operators ,
595602 instructions : ops,
603+ id,
596604 }
597605 } )
598606 }
@@ -602,6 +610,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
602610 r : JitTensorOperator < D > ,
603611 ) -> Result < Self , crate :: Error > {
604612 Ok ( {
613+ let id = Id :: combine ( & [ op. id ( ) , l. id ( ) , r. id ( ) ] ) ;
605614 let dtype = op. dtype ( l. dtype , r. dtype ) ?;
606615
607616 let ( inputs, mut ops, offset_l, offset_r) =
@@ -620,6 +629,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
620629 dtype,
621630 operators : inputs,
622631 instructions : ops,
632+ id,
623633 }
624634 } )
625635 }
@@ -630,6 +640,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
630640 a2 : JitTensorOperator < D > ,
631641 ) -> Result < Self , crate :: Error > {
632642 Ok ( {
643+ let id = Id :: combine ( & [ op. id ( ) , a0. id ( ) , a1. id ( ) , a2. id ( ) ] ) ;
633644 let dtype = op. dtype ( a0. dtype , a1. dtype , a2. dtype ) ?;
634645
635646 let ( inputs_initial, ops_initial, offset_0_initial, _) =
@@ -656,6 +667,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
656667 dtype,
657668 operators : inputs,
658669 instructions : ops,
670+ id,
659671 }
660672 } )
661673 }
@@ -665,12 +677,14 @@ impl<D: DynDimension> From<ConstValue> for JitTensorOperator<D> {
665677 fn from ( c : ConstValue ) -> Self {
666678 let dtype = c. dtype ( ) ;
667679 let op = Instruction :: NullAry ( c. dtype ( ) , NullaryOp :: Const ( c) ) ;
680+ let id = op. id ( ) ;
668681 let ops = vec ! [ op] ;
669682 Self {
670683 instructions : ops,
671684 metadata : None ,
672685 dtype,
673686 operators : OrderedSet ( Vec :: new ( ) ) ,
687+ id,
674688 }
675689 }
676690}
@@ -695,24 +709,28 @@ pub fn const_vec<D: DynDimension, V: Into<ConstValue>>(value: V) -> JitTensorOpe
695709pub fn dimensions < D : DynDimension > ( d : D ) -> JitTensorOperator < D > {
696710 let dtype = ScalarType :: U32 . vec ( d. n ( ) as _ ) ;
697711 let op = Instruction :: NullAry ( dtype, NullaryOp :: Dimensions ) ;
712+ let id = op. id ( ) ;
698713 let ops = vec ! [ op] ;
699714 JitTensorOperator {
700715 instructions : ops,
701716 metadata : None ,
702717 dtype,
703718 operators : OrderedSet ( Vec :: new ( ) ) ,
719+ id,
704720 }
705721}
706722
707723pub fn position < D : DynDimension > ( d : D ) -> JitTensorOperator < D > {
708724 let dtype = ScalarType :: U32 . vec ( d. n ( ) as _ ) ;
709725 let op = Instruction :: NullAry ( dtype, NullaryOp :: Position ) ;
726+ let id = op. id ( ) ;
710727 let ops = vec ! [ op] ;
711728 JitTensorOperator {
712729 instructions : ops,
713730 metadata : None ,
714731 dtype,
715732 operators : OrderedSet ( Vec :: new ( ) ) ,
733+ id,
716734 }
717735}
718736
@@ -722,12 +740,15 @@ impl<D: DynDimension> From<TensorOperator<D, DType>> for JitTensorOperator<D> {
722740 let metadata = Some ( c. metadata . clone ( ) ) ;
723741
724742 let op = Instruction :: NullAry ( dtype, NullaryOp :: Read ( InputId ( 0 ) ) ) ;
743+ let id = Id :: combine ( & [ c. id ( ) , op. id ( ) ] ) ;
725744 let ops = vec ! [ op] ;
745+
726746 Self {
727747 instructions : ops,
728748 metadata,
729749 dtype,
730750 operators : OrderedSet ( vec ! [ c] ) ,
751+ id,
731752 }
732753 }
733754}
@@ -1121,12 +1142,12 @@ impl<D: DynDimension> JitTensorOperator<D> {
11211142 let pipeline = device. request_state (
11221143 (
11231144 & input_dtypes,
1124- & jit_operator. instructions ,
1145+ DataParamWithExternalId ( & jit_operator. instructions , jit_operator . id ) ,
11251146 num_chunk_elements,
11261147 push_constants,
11271148 ) ,
11281149 |device, ( input_dtypes, instructions, num_chunk_elements, push_constants) | {
1129- let ( shader, config) = compile ( instructions, input_dtypes) ?;
1150+ let ( shader, config) = compile ( * instructions, input_dtypes) ?;
11301151 //println!("{}", shader.as_str());
11311152 ComputePipelineBuilder :: new (
11321153 Shader :: new ( shader. as_str ( ) )
0 commit comments