Skip to content

Commit b6376cb

Browse files
committed
Refactor HashMap handling in typed.rs to use lower_map_iter for improved iteration and memory management. Introduce new implementations for ComponentType, Lower, and Lift traits for std::collections::HashMap, enhancing support for map types in the component model.
1 parent fc4dad3 commit b6376cb

File tree

1 file changed

+180
-10
lines changed
  • crates/wasmtime/src/runtime/component/func

1 file changed

+180
-10
lines changed

crates/wasmtime/src/runtime/component/func/typed.rs

Lines changed: 180 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,7 +2138,7 @@ where
21382138
}
21392139
_ => bad_type_info(),
21402140
};
2141-
let (ptr, len) = lower_map(cx, key_ty, value_ty, self)?;
2141+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
21422142
// See "WRITEPTR64" above for why this is always storing a 64-bit
21432143
// integer.
21442144
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
@@ -2160,36 +2160,36 @@ where
21602160
_ => bad_type_info(),
21612161
};
21622162
debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
2163-
let (ptr, len) = lower_map(cx, key_ty, value_ty, self)?;
2163+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
21642164
*cx.get(offset + 0) = u32::try_from(ptr).unwrap().to_le_bytes();
21652165
*cx.get(offset + 4) = u32::try_from(len).unwrap().to_le_bytes();
21662166
Ok(())
21672167
}
21682168
}
21692169

2170-
fn lower_map<K, V, U>(
2170+
fn lower_map_iter<'a, K, V, U>(
21712171
cx: &mut LowerContext<'_, U>,
21722172
key_ty: InterfaceType,
21732173
value_ty: InterfaceType,
2174-
map: &HashMap<K, V>,
2174+
len: usize,
2175+
iter: impl Iterator<Item = (&'a K, &'a V)>,
21752176
) -> Result<(usize, usize)>
21762177
where
2177-
K: Lower,
2178-
V: Lower,
2178+
K: Lower + 'a,
2179+
V: Lower + 'a,
21792180
{
21802181
// Calculate the tuple layout: each entry is a (key, value) record.
21812182
let tuple_abi = CanonicalAbiInfo::record_static(&[K::ABI, V::ABI]);
21822183
let tuple_size = tuple_abi.size32 as usize;
21832184
let tuple_align = tuple_abi.align32;
21842185

2185-
let size = map
2186-
.len()
2186+
let size = len
21872187
.checked_mul(tuple_size)
21882188
.ok_or_else(|| format_err!("size overflow copying a map"))?;
21892189
let ptr = cx.realloc(0, 0, tuple_align, size)?;
21902190

21912191
let mut entry_offset = ptr;
2192-
for (key, value) in map.iter() {
2192+
for (key, value) in iter {
21932193
// Lower key at the start of the tuple
21942194
let mut field_offset = 0usize;
21952195
let key_field = K::ABI.next_field32_size(&mut field_offset);
@@ -2200,7 +2200,7 @@ where
22002200
entry_offset += tuple_size;
22012201
}
22022202

2203-
Ok((ptr, map.len()))
2203+
Ok((ptr, len))
22042204
}
22052205

22062206
unsafe impl<K, V> Lift for HashMap<K, V>
@@ -2293,6 +2293,176 @@ where
22932293
Ok(result)
22942294
}
22952295

2296+
// =============================================================================
2297+
// std::collections::HashMap<K, V> support for component model `map<K, V>`
2298+
//
2299+
// This mirrors the wasmtime_environ::collections::HashMap implementation above
2300+
// but works with the standard library HashMap type, which is what users will
2301+
// naturally reach for.
2302+
2303+
#[cfg(feature = "std")]
2304+
unsafe impl<K, V> ComponentType for std::collections::HashMap<K, V>
2305+
where
2306+
K: ComponentType,
2307+
V: ComponentType,
2308+
{
2309+
type Lower = [ValRaw; 2];
2310+
2311+
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;
2312+
2313+
fn typecheck(ty: &InterfaceType, types: &InstanceType<'_>) -> Result<()> {
2314+
match ty {
2315+
InterfaceType::Map(t) => {
2316+
let map_ty = &types.types[*t];
2317+
K::typecheck(&map_ty.key, types)?;
2318+
V::typecheck(&map_ty.value, types)?;
2319+
Ok(())
2320+
}
2321+
other => bail!("expected `map` found `{}`", desc(other)),
2322+
}
2323+
}
2324+
}
2325+
2326+
#[cfg(feature = "std")]
2327+
unsafe impl<K, V> Lower for std::collections::HashMap<K, V>
2328+
where
2329+
K: Lower,
2330+
V: Lower,
2331+
{
2332+
fn linear_lower_to_flat<U>(
2333+
&self,
2334+
cx: &mut LowerContext<'_, U>,
2335+
ty: InterfaceType,
2336+
dst: &mut MaybeUninit<[ValRaw; 2]>,
2337+
) -> Result<()> {
2338+
let (key_ty, value_ty) = match ty {
2339+
InterfaceType::Map(i) => {
2340+
let m = &cx.types[i];
2341+
(m.key, m.value)
2342+
}
2343+
_ => bad_type_info(),
2344+
};
2345+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
2346+
// See "WRITEPTR64" above for why this is always storing a 64-bit
2347+
// integer.
2348+
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
2349+
map_maybe_uninit!(dst[1]).write(ValRaw::i64(len as i64));
2350+
Ok(())
2351+
}
2352+
2353+
fn linear_lower_to_memory<U>(
2354+
&self,
2355+
cx: &mut LowerContext<'_, U>,
2356+
ty: InterfaceType,
2357+
offset: usize,
2358+
) -> Result<()> {
2359+
let (key_ty, value_ty) = match ty {
2360+
InterfaceType::Map(i) => {
2361+
let m = &cx.types[i];
2362+
(m.key, m.value)
2363+
}
2364+
_ => bad_type_info(),
2365+
};
2366+
debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
2367+
let (ptr, len) = lower_map_iter(cx, key_ty, value_ty, self.len(), self.iter())?;
2368+
*cx.get(offset + 0) = u32::try_from(ptr).unwrap().to_le_bytes();
2369+
*cx.get(offset + 4) = u32::try_from(len).unwrap().to_le_bytes();
2370+
Ok(())
2371+
}
2372+
}
2373+
2374+
#[cfg(feature = "std")]
2375+
unsafe impl<K, V> Lift for std::collections::HashMap<K, V>
2376+
where
2377+
K: Lift + Eq + Hash,
2378+
V: Lift,
2379+
{
2380+
fn linear_lift_from_flat(
2381+
cx: &mut LiftContext<'_>,
2382+
ty: InterfaceType,
2383+
src: &Self::Lower,
2384+
) -> Result<Self> {
2385+
let (key_ty, value_ty) = match ty {
2386+
InterfaceType::Map(i) => {
2387+
let m = &cx.types[i];
2388+
(m.key, m.value)
2389+
}
2390+
_ => bad_type_info(),
2391+
};
2392+
// FIXME(#4311): needs memory64 treatment
2393+
let ptr = src[0].get_u32();
2394+
let len = src[1].get_u32();
2395+
let (ptr, len) = (usize::try_from(ptr)?, usize::try_from(len)?);
2396+
lift_std_map(cx, key_ty, value_ty, ptr, len)
2397+
}
2398+
2399+
fn linear_lift_from_memory(
2400+
cx: &mut LiftContext<'_>,
2401+
ty: InterfaceType,
2402+
bytes: &[u8],
2403+
) -> Result<Self> {
2404+
let (key_ty, value_ty) = match ty {
2405+
InterfaceType::Map(i) => {
2406+
let m = &cx.types[i];
2407+
(m.key, m.value)
2408+
}
2409+
_ => bad_type_info(),
2410+
};
2411+
debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
2412+
// FIXME(#4311): needs memory64 treatment
2413+
let ptr = u32::from_le_bytes(bytes[..4].try_into().unwrap());
2414+
let len = u32::from_le_bytes(bytes[4..].try_into().unwrap());
2415+
let (ptr, len) = (usize::try_from(ptr)?, usize::try_from(len)?);
2416+
lift_std_map(cx, key_ty, value_ty, ptr, len)
2417+
}
2418+
}
2419+
2420+
#[cfg(feature = "std")]
2421+
fn lift_std_map<K, V>(
2422+
cx: &mut LiftContext<'_>,
2423+
key_ty: InterfaceType,
2424+
value_ty: InterfaceType,
2425+
ptr: usize,
2426+
len: usize,
2427+
) -> Result<std::collections::HashMap<K, V>>
2428+
where
2429+
K: Lift + Eq + Hash,
2430+
V: Lift,
2431+
{
2432+
let tuple_abi = CanonicalAbiInfo::record_static(&[K::ABI, V::ABI]);
2433+
let tuple_size = tuple_abi.size32 as usize;
2434+
let tuple_align = tuple_abi.align32 as usize;
2435+
2436+
match len
2437+
.checked_mul(tuple_size)
2438+
.and_then(|total| ptr.checked_add(total))
2439+
{
2440+
Some(n) if n <= cx.memory().len() => {}
2441+
_ => bail!("map pointer/length out of bounds of memory"),
2442+
}
2443+
if ptr % tuple_align != 0 {
2444+
bail!("map pointer is not aligned");
2445+
}
2446+
2447+
let mut result = std::collections::HashMap::with_capacity(len);
2448+
for i in 0..len {
2449+
let entry_base = ptr + (i * tuple_size);
2450+
2451+
let mut field_offset = 0usize;
2452+
let key_field = K::ABI.next_field32_size(&mut field_offset);
2453+
let key_bytes = &cx.memory()[entry_base + key_field..][..K::SIZE32];
2454+
let key = K::linear_lift_from_memory(cx, key_ty, key_bytes)?;
2455+
2456+
let value_field = V::ABI.next_field32_size(&mut field_offset);
2457+
let value_bytes = &cx.memory()[entry_base + value_field..][..V::SIZE32];
2458+
let value = V::linear_lift_from_memory(cx, value_ty, value_bytes)?;
2459+
2460+
result.insert(key, value);
2461+
}
2462+
2463+
Ok(result)
2464+
}
2465+
22962466
/// Verify that the given wasm type is a tuple with the expected fields in the right order.
22972467
fn typecheck_tuple(
22982468
ty: &InterfaceType,

0 commit comments

Comments
 (0)