Skip to content

Commit abdf8b7

Browse files
committed
Add alternative binary impl example for future refactoring. Debugging copy2d
1 parent a3e5e5c commit abdf8b7

File tree

2 files changed

+104
-17
lines changed

2 files changed

+104
-17
lines changed

candle-core/src/lazy/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,11 +626,12 @@ impl Display for OpEdge {
626626
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
627627
write!(
628628
f,
629-
"{} ({:?}, {:?}, {:?})",
629+
"{} ({:?}, {:?}, {:?}, {:?})",
630630
self.id.0,
631631
self.layout.shape().dims(),
632632
self.layout.stride(),
633-
self.dtype
633+
self.dtype,
634+
self.buffer_id
634635
)
635636
}
636637
}

candle-core/src/metal_backend/mod.rs

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,75 @@ impl MetalStorage {
19141914
Ok(Self::new(buffer, device.clone(), el_count, dtype))
19151915
}
19161916

1917+
pub fn alt_binary(
1918+
device: MetalDevice,
1919+
op: &'static str,
1920+
lhs_buffer: &Buffer,
1921+
lhs_l: &Layout,
1922+
lhs_dtype: DType,
1923+
rhs_buffer: &Buffer,
1924+
rhs_l: &Layout,
1925+
rhs_dtype: DType,
1926+
dst_buffer: &Buffer,
1927+
) -> Result<()> {
1928+
fn kernel_name(op: &'static str, dtype: &DType, suffix: &str) -> String {
1929+
format!("{op}_{}{}", dtype.as_str(), suffix)
1930+
}
1931+
let shape = lhs_l.shape();
1932+
let el_count = shape.elem_count();
1933+
let encoder = device.command_encoder()?;
1934+
let lhs = buffer_o(lhs_buffer, lhs_l, lhs_dtype);
1935+
let rhs = buffer_o(rhs_buffer, rhs_l, rhs_dtype);
1936+
1937+
let dtype = match op {
1938+
"eq" | "ne" | "le" | "lt" | "ge" | "gt" => DType::U8,
1939+
_ => lhs_dtype,
1940+
};
1941+
let lhs_contiguous = lhs_l.is_contiguous();
1942+
let rhs_contiguous = rhs_l.is_contiguous();
1943+
1944+
if lhs_contiguous && rhs_contiguous {
1945+
let kernel = kernel_name(op, &lhs_dtype, "");
1946+
candle_metal_kernels::call_binary_contiguous(
1947+
&device.device,
1948+
&encoder,
1949+
&device.kernels,
1950+
kernel,
1951+
lhs_dtype.size_in_bytes(),
1952+
el_count,
1953+
lhs,
1954+
rhs,
1955+
dst_buffer,
1956+
)
1957+
.map_err(MetalError::from)?;
1958+
} else {
1959+
let strided_suffix = if lhs_contiguous {
1960+
"_rstrided"
1961+
} else if rhs_contiguous {
1962+
"_lstrided"
1963+
} else {
1964+
"_strided"
1965+
};
1966+
let kernel = kernel_name(op, &lhs_dtype, strided_suffix);
1967+
candle_metal_kernels::call_binary_strided(
1968+
&device.device,
1969+
&encoder,
1970+
&device.kernels,
1971+
kernel,
1972+
lhs_dtype.size_in_bytes(),
1973+
lhs_l.dims(),
1974+
lhs,
1975+
lhs_l.stride(),
1976+
rhs,
1977+
rhs_l.stride(),
1978+
dst_buffer,
1979+
)
1980+
.map_err(MetalError::from)?;
1981+
};
1982+
encoder.set_label("binary");
1983+
Ok(())
1984+
}
1985+
19171986
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
19181987
let size = self.count * self.dtype.size_in_bytes();
19191988
let buffer = self.device.allocate_buffer(size)?;
@@ -2419,9 +2488,18 @@ impl Executor for MetalDevice {
24192488
let lhs_l = lhs_weight.layout();
24202489
let rhs_l = rhs_weight.layout();
24212490

2491+
let lhs_dtype = lhs_weight.dtype();
2492+
let rhs_dtype = rhs_weight.dtype();
2493+
24222494
let lhs_buffer = allocator.get(lhs_weight.buffer_id()).unwrap().clone();
24232495
let rhs_buffer = allocator.get(rhs_weight.buffer_id()).unwrap().clone();
24242496

2497+
let mut outgoing = graph.edges_directed(node, petgraph::Outgoing);
2498+
let dst_edge = outgoing.next().ok_or(InvalidOutgoing(op_node.id()))?;
2499+
let dst_weight = dst_edge.weight().clone();
2500+
let dst_buffer = allocator.get(dst_weight.buffer_id()).unwrap().clone();
2501+
2502+
/*
24252503
let lhs = MetalStorage::new(
24262504
Arc::new(lhs_buffer),
24272505
self.clone(),
@@ -2433,9 +2511,20 @@ impl Executor for MetalDevice {
24332511
self.clone(),
24342512
rhs_weight.layout().shape().elem_count(),
24352513
rhs_weight.dtype(),
2436-
);
2437-
let storage = lhs.binary(kernel, &rhs, lhs_l, rhs_l)?;
2438-
allocator.update_all_outgoing(graph, node, storage.buffer());
2514+
);*/
2515+
MetalStorage::alt_binary(
2516+
self.clone(),
2517+
kernel,
2518+
&lhs_buffer,
2519+
lhs_l,
2520+
lhs_dtype,
2521+
&rhs_buffer,
2522+
rhs_l,
2523+
rhs_dtype,
2524+
&dst_buffer,
2525+
)?;
2526+
//let storage = lhs.binary(kernel, &rhs, lhs_l, rhs_l)?;
2527+
allocator.update_all_outgoing(graph, node, &dst_buffer);
24392528
}
24402529
/*
24412530
ToCpu,
@@ -2611,11 +2700,7 @@ impl Executor for MetalDevice {
26112700
let mut incoming = graph.edges_directed(node, petgraph::Incoming);
26122701
let in_edge = incoming.next().ok_or(InvalidIncoming(op_node.id()))?;
26132702
let in_w = in_edge.weight().clone();
2614-
let buffer = allocator.get_or_allocate(
2615-
in_w.buffer_id(),
2616-
in_w.layout().shape(),
2617-
in_w.dtype(),
2618-
)?;
2703+
let buffer = allocator.get(in_w.buffer_id()).unwrap().clone();
26192704

26202705
let mut outgoing: Vec<(NodeIndex, NodeIndex, OpEdge)> = graph
26212706
.edges_directed(node, petgraph::Outgoing)
@@ -2624,9 +2709,9 @@ impl Executor for MetalDevice {
26242709
})
26252710
.collect();
26262711

2627-
assert!(outgoing.len() > 0);
2712+
println!("{outgoing:?}");
26282713

2629-
outgoing.sort_by(|a, b| a.0.cmp(&b.0));
2714+
assert!(outgoing.len() > 0);
26302715

26312716
/*
26322717
let (_, target, _) = outgoing.first().ok_or(InvalidOutgoing(op_node.id()))?;
@@ -2642,11 +2727,12 @@ impl Executor for MetalDevice {
26422727
let (out_source, out_target, out_w) =
26432728
outgoing.first().ok_or(InvalidOutgoing(op_node.id()))?;
26442729

2645-
let out_buffer = allocator.get_or_allocate(
2646-
out_w.buffer_id(),
2647-
out_w.layout().shape(),
2648-
out_w.dtype(),
2649-
)?;
2730+
let ancestors = crate::lazy::ancestors(&graph, *out_target);
2731+
println!("{}", crate::lazy::graph_to_dot(&&ancestors));
2732+
2733+
let out_buffer = allocator.get(out_w.buffer_id()).unwrap().clone();
2734+
2735+
println!("copy2d {:?} -> {:?}", in_w.buffer_id(), out_w.buffer_id());
26502736

26512737
let src = MetalStorage::new(
26522738
Arc::new(buffer),

0 commit comments

Comments
 (0)