|
| 1 | +use super::helpers::are_types_jvm_compatible; |
| 2 | +use crate::oomir::{self, Type}; |
| 3 | +use ristretto_classfile::{ |
| 4 | + self as jvm, ConstantPool, |
| 5 | + attributes::{ArrayType, Instruction}, |
| 6 | +}; |
| 7 | + |
| 8 | +// Helper to get the appropriate integer constant loading instruction |
| 9 | +pub fn get_int_const_instr(cp: &mut ConstantPool, val: i32) -> Instruction { |
| 10 | + match val { |
| 11 | + // Direct iconst mapping |
| 12 | + -1 => Instruction::Iconst_m1, |
| 13 | + 0 => Instruction::Iconst_0, |
| 14 | + 1 => Instruction::Iconst_1, |
| 15 | + 2 => Instruction::Iconst_2, |
| 16 | + 3 => Instruction::Iconst_3, |
| 17 | + 4 => Instruction::Iconst_4, |
| 18 | + 5 => Instruction::Iconst_5, |
| 19 | + |
| 20 | + // Bipush range (-128 to 127), excluding the iconst values already handled |
| 21 | + v @ -128..=-2 | v @ 6..=127 => Instruction::Bipush(v as i8), |
| 22 | + |
| 23 | + // Sipush range (-32768 to 32767), excluding the bipush range |
| 24 | + v @ -32768..=-129 | v @ 128..=32767 => Instruction::Sipush(v as i16), |
| 25 | + |
| 26 | + // Use LDC for values outside the -32768 to 32767 range |
| 27 | + v => { |
| 28 | + let index = cp |
| 29 | + .add_integer(v) |
| 30 | + .expect("Failed to add integer to constant pool"); |
| 31 | + if let Ok(idx8) = u8::try_from(index) { |
| 32 | + Instruction::Ldc(idx8) |
| 33 | + } else { |
| 34 | + Instruction::Ldc_w(index) |
| 35 | + } |
| 36 | + } |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +// Helper to get the appropriate long constant loading instruction |
| 41 | +pub fn get_long_const_instr(cp: &mut ConstantPool, val: i64) -> Instruction { |
| 42 | + // <-- Add `cp: &mut ConstantPool` |
| 43 | + match val { |
| 44 | + 0 => Instruction::Lconst_0, |
| 45 | + 1 => Instruction::Lconst_1, |
| 46 | + _ => { |
| 47 | + // Add the long value to the constant pool. |
| 48 | + let index = cp |
| 49 | + .add_long(val) |
| 50 | + .expect("Failed to add long to constant pool"); |
| 51 | + // Ldc2_w is used for long/double constants and always takes a u16 index. |
| 52 | + Instruction::Ldc2_w(index) |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +// Helper to get the appropriate float constant loading instruction |
| 58 | +pub fn get_float_const_instr(cp: &mut ConstantPool, val: f32) -> Instruction { |
| 59 | + if val == 0.0 { |
| 60 | + Instruction::Fconst_0 |
| 61 | + } else if val == 1.0 { |
| 62 | + Instruction::Fconst_1 |
| 63 | + } else if val == 2.0 { |
| 64 | + Instruction::Fconst_2 |
| 65 | + } else { |
| 66 | + // Add the float value to the constant pool. |
| 67 | + let index = cp |
| 68 | + .add_float(val) |
| 69 | + .expect("Failed to add float to constant pool"); |
| 70 | + // Ldc2_w is used for long/double constants and always takes a u16 index. |
| 71 | + Instruction::Ldc_w(index) |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +// Helper to get the appropriate double constant loading instruction |
| 76 | +pub fn get_double_const_instr(cp: &mut ConstantPool, val: f64) -> Instruction { |
| 77 | + // Using bit representation for exact zero comparison is more robust |
| 78 | + if val.to_bits() == 0.0f64.to_bits() { |
| 79 | + // Handles +0.0 and -0.0 |
| 80 | + Instruction::Dconst_0 |
| 81 | + } else if val == 1.0 { |
| 82 | + Instruction::Dconst_1 |
| 83 | + } else { |
| 84 | + // Add the double value to the constant pool. |
| 85 | + let index = cp |
| 86 | + .add_double(val) |
| 87 | + .expect("Failed to add double to constant pool"); |
| 88 | + // Ldc2_w is used for long/double constants and always takes a u16 index. |
| 89 | + Instruction::Ldc2_w(index) |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +/// Appends JVM instructions for loading a constant onto the stack. |
| 94 | +pub fn load_constant( |
| 95 | + instructions: &mut Vec<Instruction>, |
| 96 | + cp: &mut ConstantPool, |
| 97 | + constant: &oomir::Constant, |
| 98 | +) -> Result<(), jvm::Error> { |
| 99 | + use jvm::attributes::Instruction as JI; |
| 100 | + use oomir::Constant as OC; |
| 101 | + |
| 102 | + let mut instructions_to_add = Vec::new(); |
| 103 | + |
| 104 | + match constant { |
| 105 | + OC::I8(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)), |
| 106 | + OC::I16(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)), |
| 107 | + OC::I32(v) => instructions_to_add.push(get_int_const_instr(cp, *v)), |
| 108 | + OC::I64(v) => instructions_to_add.push(get_long_const_instr(cp, *v)), |
| 109 | + OC::F32(v) => instructions_to_add.push(get_float_const_instr(cp, *v)), |
| 110 | + OC::F64(v) => instructions_to_add.push(get_double_const_instr(cp, *v)), |
| 111 | + OC::Boolean(v) => instructions_to_add.push(if *v { JI::Iconst_1 } else { JI::Iconst_0 }), |
| 112 | + OC::Char(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)), |
| 113 | + OC::String(s) => { |
| 114 | + let index = cp.add_string(s)?; |
| 115 | + instructions_to_add.push(if let Ok(idx8) = u8::try_from(index) { |
| 116 | + JI::Ldc(idx8) |
| 117 | + } else { |
| 118 | + JI::Ldc_w(index) |
| 119 | + }); |
| 120 | + } |
| 121 | + OC::Class(c) => { |
| 122 | + let index = cp.add_class(c)?; |
| 123 | + instructions_to_add.push(if let Ok(idx8) = u8::try_from(index) { |
| 124 | + JI::Ldc(idx8) |
| 125 | + } else { |
| 126 | + JI::Ldc_w(index) |
| 127 | + }); |
| 128 | + } |
| 129 | + OC::Array(elem_ty, elements) => { |
| 130 | + let array_len = elements.len(); |
| 131 | + |
| 132 | + // 1. Push array size onto stack |
| 133 | + instructions_to_add.push(get_int_const_instr(cp, array_len as i32)); |
| 134 | + |
| 135 | + // 2. Create the new array (primitive or reference) |
| 136 | + if let Some(atype_code) = elem_ty.to_jvm_primitive_array_type_code() { |
| 137 | + let array_type = ArrayType::from_bytes(&mut std::io::Cursor::new(vec![atype_code])) // Wrap atype_code in Cursor<Vec<u8>> |
| 138 | + .map_err(|e| jvm::Error::VerificationError { |
| 139 | + context: format!("Attempting to load constant {:?}", constant), // Use Display formatting for the error type if available |
| 140 | + message: format!( |
| 141 | + "Invalid primitive array type code {}: {:?}", |
| 142 | + atype_code, e |
| 143 | + ), |
| 144 | + })?; |
| 145 | + instructions_to_add.push(JI::Newarray(array_type)); // Stack: [arrayref] |
| 146 | + } else if let Some(internal_name) = elem_ty.to_jvm_internal_name() { |
| 147 | + let class_index = cp.add_class(&internal_name)?; |
| 148 | + instructions_to_add.push(JI::Anewarray(class_index)); // Stack: [arrayref] |
| 149 | + } else { |
| 150 | + return Err(jvm::Error::VerificationError { |
| 151 | + context: format!("Attempting to load constant {:?}", constant), |
| 152 | + message: format!("Cannot create JVM array for element type: {:?}", elem_ty), |
| 153 | + }); |
| 154 | + } |
| 155 | + |
| 156 | + let store_instruction = elem_ty.get_jvm_array_store_instruction().ok_or_else(|| { |
| 157 | + jvm::Error::VerificationError { |
| 158 | + context: format!("Attempting to load constant {:?}", constant), |
| 159 | + message: format!( |
| 160 | + "Cannot determine array store instruction for type: {:?}", |
| 161 | + elem_ty |
| 162 | + ), |
| 163 | + } |
| 164 | + })?; |
| 165 | + |
| 166 | + // 3. Populate the array |
| 167 | + for (i, element_const) in elements.iter().enumerate() { |
| 168 | + let constant_type = Type::from_constant(element_const); |
| 169 | + if &constant_type != elem_ty.as_ref() |
| 170 | + && !are_types_jvm_compatible(&constant_type, elem_ty) |
| 171 | + { |
| 172 | + return Err(jvm::Error::VerificationError { |
| 173 | + context: format!("Attempting to load constant {:?}", constant), |
| 174 | + message: format!( |
| 175 | + "Type mismatch in Constant::Array: expected {:?}, found {:?} for element {}", |
| 176 | + elem_ty, constant_type, i |
| 177 | + ), |
| 178 | + }); |
| 179 | + } |
| 180 | + |
| 181 | + instructions_to_add.push(JI::Dup); // Stack: [arrayref, arrayref] |
| 182 | + instructions_to_add.push(get_int_const_instr(cp, i as i32)); // Stack: [arrayref, arrayref, index] |
| 183 | + |
| 184 | + // --- Corrected Element Loading --- |
| 185 | + // 1. Record the length of the main instruction vector *before* the recursive call. |
| 186 | + let original_jvm_len = instructions.len(); |
| 187 | + |
| 188 | + // 2. Make the recursive call. This *will* append instructions to instructions. |
| 189 | + load_constant(instructions, cp, element_const)?; |
| 190 | + |
| 191 | + // 3. Determine the range of instructions added by the recursive call. |
| 192 | + let new_jvm_len = instructions.len(); |
| 193 | + |
| 194 | + // 4. If instructions were added, copy them from instructions to instructions_to_add. |
| 195 | + if new_jvm_len > original_jvm_len { |
| 196 | + // Create a slice referencing the newly added instructions |
| 197 | + let added_instructions_slice = &instructions[original_jvm_len..new_jvm_len]; |
| 198 | + // Extend the temporary vector with a clone of these instructions |
| 199 | + instructions_to_add.extend_from_slice(added_instructions_slice); |
| 200 | + } |
| 201 | + |
| 202 | + // 5. Remove the instructions just added by the recursive call from instructions. |
| 203 | + // We truncate back to the length it had *before* the recursive call. |
| 204 | + instructions.truncate(original_jvm_len); |
| 205 | + // Now, instructions is back to its state before loading the element, |
| 206 | + // and instructions_to_add contains the necessary Dup, index, element load instructions. |
| 207 | + |
| 208 | + // Add the array store instruction to the temporary vector |
| 209 | + instructions_to_add.push(store_instruction.clone()); // Stack: [arrayref] |
| 210 | + } |
| 211 | + // Final stack state after loop: [arrayref] (the populated array) |
| 212 | + } |
| 213 | + }; |
| 214 | + |
| 215 | + // Append the generated instructions for this constant (now including array logic) |
| 216 | + instructions.extend(instructions_to_add); |
| 217 | + |
| 218 | + Ok(()) |
| 219 | +} |
0 commit comments