Skip to content

Commit f900a6c

Browse files
committed
Try to reduce jit id computation bottleneck
1 parent 44324cf commit f900a6c

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

palace-core/src/jit.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010
dim::{DDyn, DynDimension},
1111
dtypes::{AsDynType, DType, ElementType, ScalarType},
1212
op_descriptor,
13-
operator::{DataParam, OperatorDescriptor, OperatorParameter},
13+
operator::{DataParam, DataParamWithExternalId, OperatorDescriptor, OperatorParameter},
1414
operators::tensor::TensorOperator,
1515
storage::{
1616
gpu::{InplaceHandle, InplaceResult, WriteHandle},
@@ -488,12 +488,18 @@ impl<T: Identify> OrderedSet<T> {
488488
}
489489

490490
//TODO: Rc for cheap clone?
491-
#[derive(Identify, Clone)]
491+
#[derive(Clone)]
492492
pub struct JitTensorOperator<D: DynDimension> {
493493
metadata: Option<TensorMetaData<D>>,
494494
dtype: DType,
495495
operators: OrderedSet<TensorOperator<D, DType>>,
496496
instructions: Vec<Instruction>,
497+
id: Id,
498+
}
499+
impl<D: DynDimension> Identify for JitTensorOperator<D> {
500+
fn id(&self) -> Id {
501+
self.id
502+
}
497503
}
498504

499505
fn merge_instructions<D: DynDimension>(
@@ -584,6 +590,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
584590
}
585591

586592
Ok({
593+
let id = Id::combine(&[op.id(), inner.id()]);
587594
let dtype = op.dtype(inner.dtype)?;
588595
let op = Instruction::Unary(dtype, op, InstructionOffset(1));
589596
let mut ops = inner.instructions;
@@ -593,6 +600,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
593600
dtype,
594601
operators: inner.operators,
595602
instructions: ops,
603+
id,
596604
}
597605
})
598606
}
@@ -602,6 +610,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
602610
r: JitTensorOperator<D>,
603611
) -> Result<Self, crate::Error> {
604612
Ok({
613+
let id = Id::combine(&[op.id(), l.id(), r.id()]);
605614
let dtype = op.dtype(l.dtype, r.dtype)?;
606615

607616
let (inputs, mut ops, offset_l, offset_r) =
@@ -620,6 +629,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
620629
dtype,
621630
operators: inputs,
622631
instructions: ops,
632+
id,
623633
}
624634
})
625635
}
@@ -630,6 +640,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
630640
a2: JitTensorOperator<D>,
631641
) -> Result<Self, crate::Error> {
632642
Ok({
643+
let id = Id::combine(&[op.id(), a0.id(), a1.id(), a2.id()]);
633644
let dtype = op.dtype(a0.dtype, a1.dtype, a2.dtype)?;
634645

635646
let (inputs_initial, ops_initial, offset_0_initial, _) =
@@ -656,6 +667,7 @@ impl<D: DynDimension> JitTensorOperator<D> {
656667
dtype,
657668
operators: inputs,
658669
instructions: ops,
670+
id,
659671
}
660672
})
661673
}
@@ -665,12 +677,14 @@ impl<D: DynDimension> From<ConstValue> for JitTensorOperator<D> {
665677
fn from(c: ConstValue) -> Self {
666678
let dtype = c.dtype();
667679
let op = Instruction::NullAry(c.dtype(), NullaryOp::Const(c));
680+
let id = op.id();
668681
let ops = vec![op];
669682
Self {
670683
instructions: ops,
671684
metadata: None,
672685
dtype,
673686
operators: OrderedSet(Vec::new()),
687+
id,
674688
}
675689
}
676690
}
@@ -695,24 +709,28 @@ pub fn const_vec<D: DynDimension, V: Into<ConstValue>>(value: V) -> JitTensorOpe
695709
pub fn dimensions<D: DynDimension>(d: D) -> JitTensorOperator<D> {
696710
let dtype = ScalarType::U32.vec(d.n() as _);
697711
let op = Instruction::NullAry(dtype, NullaryOp::Dimensions);
712+
let id = op.id();
698713
let ops = vec![op];
699714
JitTensorOperator {
700715
instructions: ops,
701716
metadata: None,
702717
dtype,
703718
operators: OrderedSet(Vec::new()),
719+
id,
704720
}
705721
}
706722

707723
pub fn position<D: DynDimension>(d: D) -> JitTensorOperator<D> {
708724
let dtype = ScalarType::U32.vec(d.n() as _);
709725
let op = Instruction::NullAry(dtype, NullaryOp::Position);
726+
let id = op.id();
710727
let ops = vec![op];
711728
JitTensorOperator {
712729
instructions: ops,
713730
metadata: None,
714731
dtype,
715732
operators: OrderedSet(Vec::new()),
733+
id,
716734
}
717735
}
718736

@@ -722,12 +740,15 @@ impl<D: DynDimension> From<TensorOperator<D, DType>> for JitTensorOperator<D> {
722740
let metadata = Some(c.metadata.clone());
723741

724742
let op = Instruction::NullAry(dtype, NullaryOp::Read(InputId(0)));
743+
let id = Id::combine(&[c.id(), op.id()]);
725744
let ops = vec![op];
745+
726746
Self {
727747
instructions: ops,
728748
metadata,
729749
dtype,
730750
operators: OrderedSet(vec![c]),
751+
id,
731752
}
732753
}
733754
}
@@ -1121,12 +1142,12 @@ impl<D: DynDimension> JitTensorOperator<D> {
11211142
let pipeline = device.request_state(
11221143
(
11231144
&input_dtypes,
1124-
&jit_operator.instructions,
1145+
DataParamWithExternalId(&jit_operator.instructions, jit_operator.id),
11251146
num_chunk_elements,
11261147
push_constants,
11271148
),
11281149
|device, (input_dtypes, instructions, num_chunk_elements, push_constants)| {
1129-
let (shader, config) = compile(instructions, input_dtypes)?;
1150+
let (shader, config) = compile(*instructions, input_dtypes)?;
11301151
//println!("{}", shader.as_str());
11311152
ComputePipelineBuilder::new(
11321153
Shader::new(shader.as_str())

0 commit comments

Comments
 (0)