@@ -1914,6 +1914,75 @@ impl MetalStorage {
19141914 Ok ( Self :: new ( buffer, device. clone ( ) , el_count, dtype) )
19151915 }
19161916
1917+ pub fn alt_binary (
1918+ device : MetalDevice ,
1919+ op : & ' static str ,
1920+ lhs_buffer : & Buffer ,
1921+ lhs_l : & Layout ,
1922+ lhs_dtype : DType ,
1923+ rhs_buffer : & Buffer ,
1924+ rhs_l : & Layout ,
1925+ rhs_dtype : DType ,
1926+ dst_buffer : & Buffer ,
1927+ ) -> Result < ( ) > {
1928+ fn kernel_name ( op : & ' static str , dtype : & DType , suffix : & str ) -> String {
1929+ format ! ( "{op}_{}{}" , dtype. as_str( ) , suffix)
1930+ }
1931+ let shape = lhs_l. shape ( ) ;
1932+ let el_count = shape. elem_count ( ) ;
1933+ let encoder = device. command_encoder ( ) ?;
1934+ let lhs = buffer_o ( lhs_buffer, lhs_l, lhs_dtype) ;
1935+ let rhs = buffer_o ( rhs_buffer, rhs_l, rhs_dtype) ;
1936+
1937+ let dtype = match op {
1938+ "eq" | "ne" | "le" | "lt" | "ge" | "gt" => DType :: U8 ,
1939+ _ => lhs_dtype,
1940+ } ;
1941+ let lhs_contiguous = lhs_l. is_contiguous ( ) ;
1942+ let rhs_contiguous = rhs_l. is_contiguous ( ) ;
1943+
1944+ if lhs_contiguous && rhs_contiguous {
1945+ let kernel = kernel_name ( op, & lhs_dtype, "" ) ;
1946+ candle_metal_kernels:: call_binary_contiguous (
1947+ & device. device ,
1948+ & encoder,
1949+ & device. kernels ,
1950+ kernel,
1951+ lhs_dtype. size_in_bytes ( ) ,
1952+ el_count,
1953+ lhs,
1954+ rhs,
1955+ dst_buffer,
1956+ )
1957+ . map_err ( MetalError :: from) ?;
1958+ } else {
1959+ let strided_suffix = if lhs_contiguous {
1960+ "_rstrided"
1961+ } else if rhs_contiguous {
1962+ "_lstrided"
1963+ } else {
1964+ "_strided"
1965+ } ;
1966+ let kernel = kernel_name ( op, & lhs_dtype, strided_suffix) ;
1967+ candle_metal_kernels:: call_binary_strided (
1968+ & device. device ,
1969+ & encoder,
1970+ & device. kernels ,
1971+ kernel,
1972+ lhs_dtype. size_in_bytes ( ) ,
1973+ lhs_l. dims ( ) ,
1974+ lhs,
1975+ lhs_l. stride ( ) ,
1976+ rhs,
1977+ rhs_l. stride ( ) ,
1978+ dst_buffer,
1979+ )
1980+ . map_err ( MetalError :: from) ?;
1981+ } ;
1982+ encoder. set_label ( "binary" ) ;
1983+ Ok ( ( ) )
1984+ }
1985+
19171986 pub ( crate ) fn to_cpu < T : Clone > ( & self ) -> Result < Vec < T > > {
19181987 let size = self . count * self . dtype . size_in_bytes ( ) ;
19191988 let buffer = self . device . allocate_buffer ( size) ?;
@@ -2419,9 +2488,18 @@ impl Executor for MetalDevice {
24192488 let lhs_l = lhs_weight. layout ( ) ;
24202489 let rhs_l = rhs_weight. layout ( ) ;
24212490
2491+ let lhs_dtype = lhs_weight. dtype ( ) ;
2492+ let rhs_dtype = rhs_weight. dtype ( ) ;
2493+
24222494 let lhs_buffer = allocator. get ( lhs_weight. buffer_id ( ) ) . unwrap ( ) . clone ( ) ;
24232495 let rhs_buffer = allocator. get ( rhs_weight. buffer_id ( ) ) . unwrap ( ) . clone ( ) ;
24242496
2497+ let mut outgoing = graph. edges_directed ( node, petgraph:: Outgoing ) ;
2498+ let dst_edge = outgoing. next ( ) . ok_or ( InvalidOutgoing ( op_node. id ( ) ) ) ?;
2499+ let dst_weight = dst_edge. weight ( ) . clone ( ) ;
2500+ let dst_buffer = allocator. get ( dst_weight. buffer_id ( ) ) . unwrap ( ) . clone ( ) ;
2501+
2502+ /*
24252503 let lhs = MetalStorage::new(
24262504 Arc::new(lhs_buffer),
24272505 self.clone(),
@@ -2433,9 +2511,20 @@ impl Executor for MetalDevice {
24332511 self.clone(),
24342512 rhs_weight.layout().shape().elem_count(),
24352513 rhs_weight.dtype(),
2436- ) ;
2437- let storage = lhs. binary ( kernel, & rhs, lhs_l, rhs_l) ?;
2438- allocator. update_all_outgoing ( graph, node, storage. buffer ( ) ) ;
2514+ );*/
2515+ MetalStorage :: alt_binary (
2516+ self . clone ( ) ,
2517+ kernel,
2518+ & lhs_buffer,
2519+ lhs_l,
2520+ lhs_dtype,
2521+ & rhs_buffer,
2522+ rhs_l,
2523+ rhs_dtype,
2524+ & dst_buffer,
2525+ ) ?;
2526+ //let storage = lhs.binary(kernel, &rhs, lhs_l, rhs_l)?;
2527+ allocator. update_all_outgoing ( graph, node, & dst_buffer) ;
24392528 }
24402529 /*
24412530 ToCpu,
@@ -2611,11 +2700,7 @@ impl Executor for MetalDevice {
26112700 let mut incoming = graph. edges_directed ( node, petgraph:: Incoming ) ;
26122701 let in_edge = incoming. next ( ) . ok_or ( InvalidIncoming ( op_node. id ( ) ) ) ?;
26132702 let in_w = in_edge. weight ( ) . clone ( ) ;
2614- let buffer = allocator. get_or_allocate (
2615- in_w. buffer_id ( ) ,
2616- in_w. layout ( ) . shape ( ) ,
2617- in_w. dtype ( ) ,
2618- ) ?;
2703+ let buffer = allocator. get ( in_w. buffer_id ( ) ) . unwrap ( ) . clone ( ) ;
26192704
26202705 let mut outgoing: Vec < ( NodeIndex , NodeIndex , OpEdge ) > = graph
26212706 . edges_directed ( node, petgraph:: Outgoing )
@@ -2624,9 +2709,9 @@ impl Executor for MetalDevice {
26242709 } )
26252710 . collect ( ) ;
26262711
2627- assert ! ( outgoing. len ( ) > 0 ) ;
2712+ println ! ( "{ outgoing:?}" ) ;
26282713
2629- outgoing. sort_by ( |a , b| a . 0 . cmp ( & b . 0 ) ) ;
2714+ assert ! ( outgoing. len ( ) > 0 ) ;
26302715
26312716 /*
26322717 let (_, target, _) = outgoing.first().ok_or(InvalidOutgoing(op_node.id()))?;
@@ -2642,11 +2727,12 @@ impl Executor for MetalDevice {
26422727 let ( out_source, out_target, out_w) =
26432728 outgoing. first ( ) . ok_or ( InvalidOutgoing ( op_node. id ( ) ) ) ?;
26442729
2645- let out_buffer = allocator. get_or_allocate (
2646- out_w. buffer_id ( ) ,
2647- out_w. layout ( ) . shape ( ) ,
2648- out_w. dtype ( ) ,
2649- ) ?;
2730+ let ancestors = crate :: lazy:: ancestors ( & graph, * out_target) ;
2731+ println ! ( "{}" , crate :: lazy:: graph_to_dot( &&ancestors) ) ;
2732+
2733+ let out_buffer = allocator. get ( out_w. buffer_id ( ) ) . unwrap ( ) . clone ( ) ;
2734+
2735+ println ! ( "copy2d {:?} -> {:?}" , in_w. buffer_id( ) , out_w. buffer_id( ) ) ;
26502736
26512737 let src = MetalStorage :: new (
26522738 Arc :: new ( buffer) ,
0 commit comments