Skip to content

Commit a065c11

Browse files
committed
chore: minor bugs
1 parent dc4ce56 commit a065c11

File tree

1 file changed

+101
-56
lines changed

1 file changed

+101
-56
lines changed

extensions/memcpy/circuit/src/iteration.rs

Lines changed: 101 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,8 @@ impl<'a> CustomBorrow<'a, MemcpyIterRecordMut<'a>, MemcpyIterLayout> for [u8] {
376376

377377
unsafe fn extract_layout(&self) -> MemcpyIterLayout {
378378
let header: &MemcpyIterRecordHeader = self.borrow();
379-
MultiRowLayout::new(MemcpyIterMetadata {
380-
num_rows: ((header.len - header.shift as u32) >> 4) as usize + 1,
381-
})
379+
let num_rows = ((header.len - header.shift as u32) >> 4) as usize + 1;
380+
MultiRowLayout::new(MemcpyIterMetadata { num_rows })
382381
}
383382
}
384383

@@ -442,10 +441,21 @@ where
442441
);
443442
let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32);
444443

445-
// Create a record with var_size = ((len - shift) >> 4) + 1 which is the number of rows in iteration trace
446-
let record = state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata {
447-
num_rows: ((len - shift as u32) >> 4) as usize + 1,
448-
}));
444+
// Create a record sized to the exact number of 16-byte iterations (header + iterations)
445+
// This calculation must match extract_layout and fill_trace
446+
447+
// FIX 1: prevent underflow when len < shift
448+
let effective_len = len.saturating_sub(shift as u32);
449+
let num_iters = (effective_len >> 4) as usize;
450+
eprintln!("num_iters = {:?}", num_iters);
451+
// eprintln!("state.pc = {:?}", state.pc);
452+
// eprintln!("state.memory.timestamp = {:?}", state.memory.timestamp);
453+
// eprintln!("state.memory.data() = {:?}", state.memory.data());
454+
let record: MemcpyIterRecordMut<'_> =
455+
state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata {
456+
//allocating based on number of rows needed
457+
num_rows: num_iters + 1,
458+
})); // is this too big then??
449459

450460
// Store the original values in the record
451461
record.inner.shift = shift;
@@ -454,21 +464,29 @@ where
454464
record.inner.dest = dest;
455465
record.inner.source = source;
456466
record.inner.len = len;
467+
eprintln!(
468+
"shift = {:?}, len = {:?}, source = {:?}, source%16 = {:?}, dest = {:?}, dest%16 = {:?}",
469+
shift, len, source, source % 16, dest, dest % 16
470+
);
457471

458472
// Fill record.var for the first row of iteration trace
473+
// FIX 2: read source-4 (the word ending at s[-1]); zero if out-of-bounds.
459474
if shift != 0 {
460-
source -= 12;
461-
record.var[0].data[3] = tracing_read(
462-
state.memory,
463-
RV32_MEMORY_AS,
464-
source - 4,
465-
&mut record.var[0].read_aux[3].prev_timestamp,
466-
);
467-
};
475+
if source >= 4 {
476+
record.var[0].data[3] = tracing_read(
477+
state.memory,
478+
RV32_MEMORY_AS,
479+
source - 4, // correct seed for mixing
480+
&mut record.var[0].read_aux[3].prev_timestamp,
481+
);
482+
} else {
483+
record.var[0].data[3] = [0; 4];
484+
}
485+
}
468486

469487
// Fill record.var for the rest of the rows of iteration trace
470488
let mut idx = 1;
471-
while len - shift as u32 > 15 {
489+
for _ in 0..num_iters {
472490
let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| {
473491
record.var[idx].data[i] = tracing_read(
474492
state.memory,
@@ -477,12 +495,17 @@ where
477495
&mut record.var[idx].read_aux[i].prev_timestamp,
478496
);
479497
let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|j| {
480-
if j < 4 - shift as usize {
481-
record.var[idx].data[i][j + shift as usize]
482-
} else if i > 0 {
483-
record.var[idx].data[i - 1][j - (4 - shift as usize)]
498+
if j < shift as usize {
499+
if i > 0 {
500+
// First s bytes come from previous 4-byte word tail
501+
record.var[idx].data[i - 1][j + (4 - shift as usize)]
502+
} else {
503+
// For i == 0, take from previous chunk's last word tail
504+
record.var[idx - 1].data[3][j + (4 - shift as usize)]
505+
}
484506
} else {
485-
record.var[idx - 1].data[3][j - (4 - shift as usize)]
507+
// Remaining 4 - s bytes come from current word head
508+
record.var[idx].data[i][j - shift as usize]
486509
}
487510
});
488511
write_data
@@ -503,11 +526,6 @@ where
503526
idx += 1;
504527
}
505528

506-
// Handle the core loop
507-
if shift != 0 {
508-
source += 12;
509-
}
510-
511529
let mut dest_data = [0; 4];
512530
let mut source_data = [0; 4];
513531
let mut len_data = [0; 4];
@@ -598,7 +616,7 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
598616
.par_iter_mut()
599617
.zip(sizes.par_iter())
600618
.enumerate()
601-
.for_each(|(row_idx, (chunk, &num_rows))| {
619+
.for_each(|(_row_idx, (chunk, &num_rows))| {
602620
let record: MemcpyIterRecordMut = unsafe {
603621
get_record_from_slice(
604622
chunk,
@@ -607,7 +625,7 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
607625
};
608626

609627
tracing::info!("shift: {:?}", record.inner.shift);
610-
// Fill memcpy loop record
628+
611629
self.memcpy_loop_chip.add_new_loop(
612630
mem_helper,
613631
record.inner.from_pc,
@@ -633,8 +651,8 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
633651
};
634652

635653
let mut dest = record.inner.dest + ((num_rows - 1) << 4) as u32;
636-
let mut source = record.inner.source + ((num_rows - 1) << 4) as u32
637-
- 12 * (record.inner.shift != 0) as u32;
654+
let mut source = (record.inner.source + ((num_rows - 1) << 4) as u32)
655+
.saturating_sub(12 * (record.inner.shift != 0) as u32);
638656
let mut len =
639657
record.inner.len - ((num_rows - 1) << 4) as u32 - record.inner.shift as u32;
640658

@@ -737,8 +755,8 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
737755
cols.dest = F::from_canonical_u32(dest);
738756
cols.timestamp = F::from_canonical_u32(get_timestamp(false));
739757

740-
dest -= 16;
741-
source -= 16;
758+
dest = dest.saturating_sub(16);
759+
source = source.saturating_sub(16);
742760
len += 16;
743761

744762
// if row_idx == 0 && is_start {
@@ -914,26 +932,24 @@ impl<F: PrimeField32> MeteredExecutor<F> for MemcpyIterExecutor {
914932
#[inline(always)]
915933
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
916934
pre_compute: &MemcpyIterPreCompute,
917-
instret: &mut u64,
918-
pc: &mut u32,
919-
exec_state: &mut VmExecState<F, GuestMemory, CTX>,
935+
vm_state: &mut VmExecState<F, GuestMemory, CTX>,
920936
) -> u32 {
921937
let shift = pre_compute.c;
922938
let mut height = 1;
923939
// Read dest and source from registers
924940
let (dest, source) = if shift == 0 {
925941
(
926-
exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
927-
exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A4_REGISTER_PTR as u32),
942+
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
943+
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A4_REGISTER_PTR as u32),
928944
)
929945
} else {
930946
(
931-
exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A1_REGISTER_PTR as u32),
932-
exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
947+
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A1_REGISTER_PTR as u32),
948+
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
933949
)
934950
};
935951
// Read length from a2 register
936-
let len = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A2_REGISTER_PTR as u32);
952+
let len = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A2_REGISTER_PTR as u32);
937953

938954
let mut dest = u32::from_le_bytes(dest);
939955
let mut source = u32::from_le_bytes(source) - 12 * (shift != 0) as u32;
@@ -950,24 +966,26 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
950966
debug_assert!(to_dest <= source || to_source <= dest);
951967

952968
// Read the previous data from memory if shift != 0
969+
// Note: when shift != 0, `source` has been adjusted by -12 to align reads,
970+
// so the previous word is at original_source - 4 == (source + 12) - 4 == source + 8.
953971
let mut prev_data = if shift == 0 {
954972
[0; 4]
955973
} else {
956-
exec_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source - 4)
974+
vm_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source - 4)
957975
};
958976

959977
// Run iterations
960978
while len - shift as u32 > 15 {
961979
for i in 0..4 {
962-
let data = exec_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source + 4 * i);
980+
let data = vm_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source + 4 * i);
963981
let write_data: [u8; 4] = array::from_fn(|i| {
964982
if i < 4 - shift as usize {
965983
data[i + shift as usize]
966984
} else {
967985
prev_data[i - (4 - shift as usize)]
968986
}
969987
});
970-
exec_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data);
988+
vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data);
971989
prev_data = data;
972990
}
973991
len -= 16;
@@ -976,59 +994,86 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
976994
height += 1;
977995
}
978996

997+
// Handle remaining bytes (len in [0, 15]) so that total `copy_len == original len`.
998+
if len > 0 {
999+
let remaining_words = ((len + 3) >> 2) as usize;
1000+
for i in 0..remaining_words.min(4) {
1001+
let data = vm_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source + 4 * i as u32);
1002+
let write_data: [u8; 4] = array::from_fn(|j| {
1003+
if j < shift as usize {
1004+
prev_data[j + (4 - shift as usize)]
1005+
} else {
1006+
data[j - shift as usize]
1007+
}
1008+
});
1009+
vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i as u32, &write_data);
1010+
prev_data = data;
1011+
}
1012+
// Advance pointers to reflect bytes written
1013+
let advanced = (remaining_words as u32) << 2;
1014+
source += advanced;
1015+
dest += advanced;
1016+
len = 0;
1017+
height += 1;
1018+
}
1019+
9791020
// Write the result back to memory
9801021
if shift == 0 {
981-
exec_state.vm_write(
1022+
vm_state.vm_write(
9821023
RV32_REGISTER_AS,
9831024
A3_REGISTER_PTR as u32,
9841025
&dest.to_le_bytes(),
9851026
);
986-
exec_state.vm_write(
1027+
vm_state.vm_write(
9871028
RV32_REGISTER_AS,
9881029
A4_REGISTER_PTR as u32,
9891030
&source.to_le_bytes(),
9901031
);
9911032
} else {
9921033
source += 12;
993-
exec_state.vm_write(
1034+
vm_state.vm_write(
9941035
RV32_REGISTER_AS,
9951036
A1_REGISTER_PTR as u32,
9961037
&dest.to_le_bytes(),
9971038
);
998-
exec_state.vm_write(
1039+
vm_state.vm_write(
9991040
RV32_REGISTER_AS,
10001041
A3_REGISTER_PTR as u32,
10011042
&source.to_le_bytes(),
10021043
);
10031044
};
1004-
exec_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes());
1045+
vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes());
10051046

1006-
*pc = pc.wrapping_add(DEFAULT_PC_STEP);
1007-
*instret += 1;
1047+
*vm_state.pc_mut() = vm_state.pc().wrapping_add(DEFAULT_PC_STEP);
1048+
*vm_state.instret_mut() = vm_state.instret() + 1;
10081049
height
10091050
}
10101051

10111052
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
10121053
pre_compute: &[u8],
10131054
instret: &mut u64,
10141055
pc: &mut u32,
1015-
_instret_end: u64,
1016-
exec_state: &mut VmExecState<F, GuestMemory, CTX>,
1056+
_arg: u64,
1057+
vm_state: &mut VmExecState<F, GuestMemory, CTX>,
10171058
) {
10181059
let pre_compute: &MemcpyIterPreCompute = pre_compute.borrow();
1019-
execute_e12_impl::<F, CTX>(pre_compute, instret, pc, exec_state);
1060+
let height = execute_e12_impl::<F, CTX>(pre_compute, vm_state);
1061+
*instret += height as u64;
1062+
*pc = vm_state.pc();
10201063
}
10211064

10221065
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
10231066
pre_compute: &[u8],
10241067
instret: &mut u64,
10251068
pc: &mut u32,
10261069
_arg: u64,
1027-
exec_state: &mut VmExecState<F, GuestMemory, CTX>,
1070+
vm_state: &mut VmExecState<F, GuestMemory, CTX>,
10281071
) {
10291072
let pre_compute: &E2PreCompute<MemcpyIterPreCompute> = pre_compute.borrow();
1030-
let height = execute_e12_impl::<F, CTX>(&pre_compute.data, instret, pc, exec_state);
1031-
exec_state
1073+
let height = execute_e12_impl::<F, CTX>(&pre_compute.data, vm_state);
1074+
*instret += height as u64;
1075+
*pc = vm_state.pc();
1076+
vm_state
10321077
.ctx
10331078
.on_height_change(pre_compute.chip_idx as usize, height);
10341079
}

0 commit comments

Comments
 (0)