Skip to content

Commit 5ea1cf3

Browse files
committed
Maybe fix build on GPU
1 parent c93da50 commit 5ea1cf3

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

openvm/src/trace_generation.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use tracing::info_span;
1313
use crate::{BabyBearSC, CompiledProgram, SpecializedConfigCpuBuilder};
1414

1515
#[cfg(not(feature = "cuda"))]
16-
use crate::PowdrSdkCpu;
16+
use crate::PowdrSdkCpu as PowdrSdk;
1717
#[cfg(feature = "cuda")]
18-
use crate::PowdrSdkGpu;
18+
use crate::PowdrSdkGpu as PowdrSdk;
1919

20+
#[cfg(feature = "cuda")]
21+
use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine as BabyBearPoseidon2Engine;
2022
#[cfg(not(feature = "cuda"))]
2123
use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine;
22-
#[cfg(feature = "cuda")]
23-
use openvm_stark_sdk::config::gpu_baby_bear_poseidon2::GpuBabyBearPoseidon2Engine;
2424

2525
/// Given a program and input, generates the trace segment by segment and calls the provided
2626
/// callback with the VM, proving key, and proving context (containing the trace) for each segment.
@@ -42,10 +42,7 @@ pub fn do_with_trace(
4242
let app_config = AppConfig::new(app_fri_params, vm_config.clone());
4343

4444
// Create the SDK
45-
#[cfg(feature = "cuda")]
46-
let sdk = PowdrSdkGpu::new(app_config).unwrap();
47-
#[cfg(not(feature = "cuda"))]
48-
let sdk = PowdrSdkCpu::new(app_config).unwrap();
45+
let sdk = PowdrSdk::new(app_config).unwrap();
4946
// Build owned vm instance, so we can mutate it later
5047
let vm_builder = sdk.app_vm_builder().clone();
5148
let vm_pk = sdk.app_pk().app_vm_pk.clone();

0 commit comments

Comments
 (0)