Skip to content

Commit c4babb4

Browse files
committed
Refactor of DataType handling in lower1 and lower2 to provide early support for enums
1 parent 56e767f commit c4babb4

File tree

7 files changed

+486
-292
lines changed

7 files changed

+486
-292
lines changed

src/lower1/control_flow/rvalue.rs

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,120 @@ pub fn convert_rvalue_to_operand<'a>(
543543
data_types.insert(
544544
jvm_class_name.clone(),
545545
oomir::DataType {
546-
name: jvm_class_name,
547546
fields: oomir_fields,
547+
is_abstract: false,
548+
methods: HashMap::new(),
549+
super_class: None,
548550
},
549551
);
550552
}
553+
} else if adt_def.is_enum() {
554+
let variant_def = adt_def.variant(*variant_idx);
555+
let base_enum_name = make_jvm_safe(&tcx.def_path_str(adt_def.did()));
556+
let variant_class_name = format!(
557+
"{}${}",
558+
base_enum_name,
559+
make_jvm_safe(&variant_def.name.to_string())
560+
);
561+
562+
println!(
563+
"Info: Handling Enum Aggregate (Variant: {}) -> Temp Var '{}' (Class: {})",
564+
variant_def.name, temp_aggregate_var, variant_class_name
565+
);
566+
567+
/*
568+
e.g. consider this Rust code:
569+
```rust
570+
enum MyEnum {
571+
A(i32),
572+
B{x: String},
573+
C,
574+
}
575+
```
576+
577+
pseudo-Java for the plan on how to handle this
578+
```java
579+
abstract class MyEnum {
580+
public abstract int getVariantIdx();
581+
}
582+
583+
class MyEnum$A extends MyEnum {
584+
public int field0;
585+
586+
public MyEnum$A(int field0) {
587+
this.field0 = field0;
588+
}
589+
590+
@Override
591+
public final int getVariantIdx() { return 0; }
592+
}
593+
594+
class MyEnum$B extends MyEnum {
595+
public String field0;
596+
597+
public MyEnum$B(String field0) {
598+
this.field0 = field0;
599+
}
600+
601+
@Override
602+
public final int getVariantIdx() { return 1; }
603+
}
604+
605+
class MyEnum$C extends MyEnum {
606+
@Override
607+
public final int getVariantIdx() { return 2; }
608+
}
609+
```
610+
*/
611+
612+
// the enum in general
613+
if !data_types.contains_key(&base_enum_name) {
614+
let mut methods = HashMap::new();
615+
methods.insert("getVariantIdx".to_string(), (oomir::Type::I32, None));
616+
data_types.insert(
617+
base_enum_name.clone(),
618+
oomir::DataType {
619+
fields: vec![], // No fields in the abstract class
620+
is_abstract: true,
621+
methods,
622+
super_class: None,
623+
},
624+
);
625+
}
626+
627+
// this variant
628+
if !data_types.contains_key(&variant_class_name) {
629+
let mut fields = vec![];
630+
for (i, field) in variant_def.fields.iter().enumerate() {
631+
let field_name = format!("field{}", i);
632+
let field_type =
633+
ty_to_oomir_type(field.ty(tcx, substs), tcx, data_types);
634+
fields.push((field_name, field_type));
635+
}
636+
637+
let mut methods = HashMap::new();
638+
methods.insert(
639+
"getVariantIdx".to_string(),
640+
(
641+
oomir::Type::I32,
642+
Some(oomir::Constant::I32(variant_idx.as_u32() as i32)),
643+
),
644+
);
645+
646+
data_types.insert(
647+
variant_class_name.clone(),
648+
oomir::DataType {
649+
fields,
650+
is_abstract: false,
651+
methods,
652+
super_class: Some(base_enum_name.clone()),
653+
},
654+
);
655+
}
656+
657+
// Construct the enum variant object
551658
} else {
552-
// Enum, Union
659+
// Union
553660
println!("Warning: Unhandled ADT Aggregate Kind -> Temp Placeholder");
554661
// make a placeholder (Class("java/lang/Object"))
555662
instructions.push(oomir::Instruction::ConstructObject {

src/lower1/types.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,27 @@ pub fn ty_to_oomir_type<'tcx>(
6161
data_types.insert(
6262
jvm_name.clone(),
6363
oomir::DataType {
64-
name: jvm_name.clone(),
6564
fields: oomir_fields,
65+
is_abstract: false,
66+
methods: HashMap::new(),
67+
super_class: None,
6668
},
6769
);
70+
} else if adt_def.is_enum() {
71+
// the enum in general
72+
if !data_types.contains_key(&jvm_name) {
73+
let mut methods = HashMap::new();
74+
methods.insert("getVariantIdx".to_string(), (oomir::Type::I32, None));
75+
data_types.insert(
76+
jvm_name.clone(),
77+
oomir::DataType {
78+
fields: vec![], // No fields in the abstract class
79+
is_abstract: true,
80+
methods,
81+
super_class: None,
82+
},
83+
);
84+
}
6885
}
6986
oomir::Type::Class(jvm_name)
7087
}
@@ -119,8 +136,10 @@ pub fn ty_to_oomir_type<'tcx>(
119136

120137
// Create and insert the DataType definition
121138
let tuple_data_type = oomir::DataType {
122-
name: tuple_class_name.clone(),
123139
fields: oomir_fields,
140+
is_abstract: false,
141+
methods: HashMap::new(),
142+
super_class: None,
124143
};
125144
data_types.insert(tuple_class_name.clone(), tuple_data_type);
126145
println!(" -> Added DataType: {:?}", data_types[&tuple_class_name]);

src/lower2.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use ristretto_classfile::{
1313
use rustc_middle::ty::TyCtxt;
1414
use std::collections::HashMap;
1515

16+
mod consts;
1617
mod helpers;
1718
mod jvm_gen;
1819
mod shim;

src/lower2/consts.rs

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
use super::helpers::are_types_jvm_compatible;
2+
use crate::oomir::{self, Type};
3+
use ristretto_classfile::{
4+
self as jvm, ConstantPool,
5+
attributes::{ArrayType, Instruction},
6+
};
7+
8+
// Helper to get the appropriate integer constant loading instruction
9+
pub fn get_int_const_instr(cp: &mut ConstantPool, val: i32) -> Instruction {
10+
match val {
11+
// Direct iconst mapping
12+
-1 => Instruction::Iconst_m1,
13+
0 => Instruction::Iconst_0,
14+
1 => Instruction::Iconst_1,
15+
2 => Instruction::Iconst_2,
16+
3 => Instruction::Iconst_3,
17+
4 => Instruction::Iconst_4,
18+
5 => Instruction::Iconst_5,
19+
20+
// Bipush range (-128 to 127), excluding the iconst values already handled
21+
v @ -128..=-2 | v @ 6..=127 => Instruction::Bipush(v as i8),
22+
23+
// Sipush range (-32768 to 32767), excluding the bipush range
24+
v @ -32768..=-129 | v @ 128..=32767 => Instruction::Sipush(v as i16),
25+
26+
// Use LDC for values outside the -32768 to 32767 range
27+
v => {
28+
let index = cp
29+
.add_integer(v)
30+
.expect("Failed to add integer to constant pool");
31+
if let Ok(idx8) = u8::try_from(index) {
32+
Instruction::Ldc(idx8)
33+
} else {
34+
Instruction::Ldc_w(index)
35+
}
36+
}
37+
}
38+
}
39+
40+
// Helper to get the appropriate long constant loading instruction
41+
pub fn get_long_const_instr(cp: &mut ConstantPool, val: i64) -> Instruction {
42+
// <-- Add `cp: &mut ConstantPool`
43+
match val {
44+
0 => Instruction::Lconst_0,
45+
1 => Instruction::Lconst_1,
46+
_ => {
47+
// Add the long value to the constant pool.
48+
let index = cp
49+
.add_long(val)
50+
.expect("Failed to add long to constant pool");
51+
// Ldc2_w is used for long/double constants and always takes a u16 index.
52+
Instruction::Ldc2_w(index)
53+
}
54+
}
55+
}
56+
57+
// Helper to get the appropriate float constant loading instruction
58+
pub fn get_float_const_instr(cp: &mut ConstantPool, val: f32) -> Instruction {
59+
if val == 0.0 {
60+
Instruction::Fconst_0
61+
} else if val == 1.0 {
62+
Instruction::Fconst_1
63+
} else if val == 2.0 {
64+
Instruction::Fconst_2
65+
} else {
66+
// Add the float value to the constant pool.
67+
let index = cp
68+
.add_float(val)
69+
.expect("Failed to add float to constant pool");
70+
// Ldc2_w is used for long/double constants and always takes a u16 index.
71+
Instruction::Ldc_w(index)
72+
}
73+
}
74+
75+
// Helper to get the appropriate double constant loading instruction
76+
pub fn get_double_const_instr(cp: &mut ConstantPool, val: f64) -> Instruction {
77+
// Using bit representation for exact zero comparison is more robust
78+
if val.to_bits() == 0.0f64.to_bits() {
79+
// Handles +0.0 and -0.0
80+
Instruction::Dconst_0
81+
} else if val == 1.0 {
82+
Instruction::Dconst_1
83+
} else {
84+
// Add the double value to the constant pool.
85+
let index = cp
86+
.add_double(val)
87+
.expect("Failed to add double to constant pool");
88+
// Ldc2_w is used for long/double constants and always takes a u16 index.
89+
Instruction::Ldc2_w(index)
90+
}
91+
}
92+
93+
/// Appends JVM instructions for loading a constant onto the stack.
94+
pub fn load_constant(
95+
instructions: &mut Vec<Instruction>,
96+
cp: &mut ConstantPool,
97+
constant: &oomir::Constant,
98+
) -> Result<(), jvm::Error> {
99+
use jvm::attributes::Instruction as JI;
100+
use oomir::Constant as OC;
101+
102+
let mut instructions_to_add = Vec::new();
103+
104+
match constant {
105+
OC::I8(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)),
106+
OC::I16(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)),
107+
OC::I32(v) => instructions_to_add.push(get_int_const_instr(cp, *v)),
108+
OC::I64(v) => instructions_to_add.push(get_long_const_instr(cp, *v)),
109+
OC::F32(v) => instructions_to_add.push(get_float_const_instr(cp, *v)),
110+
OC::F64(v) => instructions_to_add.push(get_double_const_instr(cp, *v)),
111+
OC::Boolean(v) => instructions_to_add.push(if *v { JI::Iconst_1 } else { JI::Iconst_0 }),
112+
OC::Char(v) => instructions_to_add.push(get_int_const_instr(cp, *v as i32)),
113+
OC::String(s) => {
114+
let index = cp.add_string(s)?;
115+
instructions_to_add.push(if let Ok(idx8) = u8::try_from(index) {
116+
JI::Ldc(idx8)
117+
} else {
118+
JI::Ldc_w(index)
119+
});
120+
}
121+
OC::Class(c) => {
122+
let index = cp.add_class(c)?;
123+
instructions_to_add.push(if let Ok(idx8) = u8::try_from(index) {
124+
JI::Ldc(idx8)
125+
} else {
126+
JI::Ldc_w(index)
127+
});
128+
}
129+
OC::Array(elem_ty, elements) => {
130+
let array_len = elements.len();
131+
132+
// 1. Push array size onto stack
133+
instructions_to_add.push(get_int_const_instr(cp, array_len as i32));
134+
135+
// 2. Create the new array (primitive or reference)
136+
if let Some(atype_code) = elem_ty.to_jvm_primitive_array_type_code() {
137+
let array_type = ArrayType::from_bytes(&mut std::io::Cursor::new(vec![atype_code])) // Wrap atype_code in Cursor<Vec<u8>>
138+
.map_err(|e| jvm::Error::VerificationError {
139+
context: format!("Attempting to load constant {:?}", constant), // Use Display formatting for the error type if available
140+
message: format!(
141+
"Invalid primitive array type code {}: {:?}",
142+
atype_code, e
143+
),
144+
})?;
145+
instructions_to_add.push(JI::Newarray(array_type)); // Stack: [arrayref]
146+
} else if let Some(internal_name) = elem_ty.to_jvm_internal_name() {
147+
let class_index = cp.add_class(&internal_name)?;
148+
instructions_to_add.push(JI::Anewarray(class_index)); // Stack: [arrayref]
149+
} else {
150+
return Err(jvm::Error::VerificationError {
151+
context: format!("Attempting to load constant {:?}", constant),
152+
message: format!("Cannot create JVM array for element type: {:?}", elem_ty),
153+
});
154+
}
155+
156+
let store_instruction = elem_ty.get_jvm_array_store_instruction().ok_or_else(|| {
157+
jvm::Error::VerificationError {
158+
context: format!("Attempting to load constant {:?}", constant),
159+
message: format!(
160+
"Cannot determine array store instruction for type: {:?}",
161+
elem_ty
162+
),
163+
}
164+
})?;
165+
166+
// 3. Populate the array
167+
for (i, element_const) in elements.iter().enumerate() {
168+
let constant_type = Type::from_constant(element_const);
169+
if &constant_type != elem_ty.as_ref()
170+
&& !are_types_jvm_compatible(&constant_type, elem_ty)
171+
{
172+
return Err(jvm::Error::VerificationError {
173+
context: format!("Attempting to load constant {:?}", constant),
174+
message: format!(
175+
"Type mismatch in Constant::Array: expected {:?}, found {:?} for element {}",
176+
elem_ty, constant_type, i
177+
),
178+
});
179+
}
180+
181+
instructions_to_add.push(JI::Dup); // Stack: [arrayref, arrayref]
182+
instructions_to_add.push(get_int_const_instr(cp, i as i32)); // Stack: [arrayref, arrayref, index]
183+
184+
// --- Corrected Element Loading ---
185+
// 1. Record the length of the main instruction vector *before* the recursive call.
186+
let original_jvm_len = instructions.len();
187+
188+
// 2. Make the recursive call. This *will* append instructions to instructions.
189+
load_constant(instructions, cp, element_const)?;
190+
191+
// 3. Determine the range of instructions added by the recursive call.
192+
let new_jvm_len = instructions.len();
193+
194+
// 4. If instructions were added, copy them from instructions to instructions_to_add.
195+
if new_jvm_len > original_jvm_len {
196+
// Create a slice referencing the newly added instructions
197+
let added_instructions_slice = &instructions[original_jvm_len..new_jvm_len];
198+
// Extend the temporary vector with a clone of these instructions
199+
instructions_to_add.extend_from_slice(added_instructions_slice);
200+
}
201+
202+
// 5. Remove the instructions just added by the recursive call from instructions.
203+
// We truncate back to the length it had *before* the recursive call.
204+
instructions.truncate(original_jvm_len);
205+
// Now, instructions is back to its state before loading the element,
206+
// and instructions_to_add contains the necessary Dup, index, element load instructions.
207+
208+
// Add the array store instruction to the temporary vector
209+
instructions_to_add.push(store_instruction.clone()); // Stack: [arrayref]
210+
}
211+
// Final stack state after loop: [arrayref] (the populated array)
212+
}
213+
};
214+
215+
// Append the generated instructions for this constant (now including array logic)
216+
instructions.extend(instructions_to_add);
217+
218+
Ok(())
219+
}

0 commit comments

Comments
 (0)