@@ -12,39 +12,86 @@ tt.func @get_rank(
1212}
1313
1414tt.func @get_peer_ptr (
15- %arg0: !tt.ptr <i64 >, %peer_id: i64 , %metadata: !tt.ptr <i64 >
15+ %arg0: !tt.ptr <i64 >, %arg1: !tt.ptr < i64 >, % peer_id: i64 , %metadata: !tt.ptr <i64 >
1616) -> !tt.ptr <i64 > {
1717 // CHECK-NOT: triton_xla.get_peer_ptr
18- // Byte size of a pointer.
18+ // Byte offset from the beginning of metadata to the peer pointers for %arg1:
19+ // offset(param_to_peers) + sizeof(uint64_t) * 2 = 24
20+ // CHECK: %c24_i64 = arith.constant 24 : i64
21+ // Size of a uint64_t, in bytes.
1922 // CHECK: %c8_i64 = arith.constant 8 : i64
2023
2124 // Load metadata->rank
22- // CHECK-NEXT: %0 = tt.load %arg2 : !tt.ptr<i64>
25+ // CHECK-NEXT: %0 = tt.load %arg3 : !tt.ptr<i64>
2326
2427 // Calculate offset to current base pointer.
2528 // CHECK-NEXT: %1 = arith.muli %0, %c8_i64 : i64
2629
27- // Load metadata->buffer_root_ptrs[metadata->rank].
30+ // Load metadata->param_to_peers[argument_offset + metadata->rank].
31+ // Here argument_offset = 0 since %arg0 is the first argument.
2832 // CHECK-NEXT: %2 = arith.addi %1, %c8_i64 : i64
29- // CHECK-NEXT: %3 = tt.addptr %arg2 , %2 : !tt.ptr<i64>, i64
33+ // CHECK-NEXT: %3 = tt.addptr %arg3 , %2 : !tt.ptr<i64>, i64
3034 // CHECK-NEXT: %4 = tt.load %3 : !tt.ptr<i64>
3135
3236 // Calculate offset to address.
3337 // CHECK-NEXT: %5 = tt.ptr_to_int %arg0 : !tt.ptr<i64> -> i64
3438 // CHECK-NEXT: %6 = arith.subi %5, %4 : i64
3539
3640 // Calculate offset to peer base pointer.
37- // CHECK-NEXT: %7 = arith.muli %arg1 , %c8_i64 : i64
41+ // CHECK-NEXT: %7 = arith.muli %arg2 , %c8_i64 : i64
3842 // CHECK-NEXT: %8 = arith.addi %7, %c8_i64 : i64
3943
40- // Load metadata->buffer_root_ptrs[ peer_id].
41- // CHECK-NEXT: %9 = tt.addptr %arg2 , %8 : !tt.ptr<i64>, i64
44+ // Load metadata->param_to_peers[argument_offset + peer_id].
45+ // CHECK-NEXT: %9 = tt.addptr %arg3 , %8 : !tt.ptr<i64>, i64
4246 // CHECK-NEXT: %10 = tt.load %9 : !tt.ptr<i64>
4347
44- // Load metadata->buffer_root_ptrs[peer_id] + offset.
48+ // Compute metadata->param_to_peers[argument_offset + peer_id] + offset.
4549 // CHECK-NEXT: %11 = arith.addi %10, %6 : i64
4650 // CHECK-NEXT: %12 = tt.int_to_ptr %11 : i64 -> !tt.ptr<i64>
47- // CHECK-NEXT: tt.return %12 : !tt.ptr<i64>
48- %peer_ptr = triton_xla.get_peer_ptr %arg0 , %peer_id , %metadata : (!tt.ptr <i64 >, i64 , !tt.ptr <i64 >) -> !tt.ptr <i64 >
49- tt.return %peer_ptr : !tt.ptr <i64 >
51+ %arg_0_peer_ptr = triton_xla.get_peer_ptr %arg0 , %peer_id , %metadata ,
52+ { argument_index = 0 : i32 , world_size = 2 : i32 } :
53+ (!tt.ptr <i64 >, i64 , !tt.ptr <i64 >) -> !tt.ptr <i64 >
54+
55+ // Load metadata->rank
56+ // CHECK-NEXT: %13 = tt.load %arg3 : !tt.ptr<i64>
57+ // Calculate offset to current base pointer.
58+ // CHECK-NEXT: %14 = arith.muli %13, %c8_i64 : i64
59+ // Load metadata->param_to_peers[argument_offset + metadata->rank].
60+ // CHECK-NEXT: %15 = arith.addi %14, %c24_i64 : i64
61+ // CHECK-NEXT: %16 = tt.addptr %arg3, %15 : !tt.ptr<i64>, i64
62+ // CHECK-NEXT: %17 = tt.load %16 : !tt.ptr<i64>
63+ // Calculate offset to address.
64+ // CHECK-NEXT: %18 = tt.ptr_to_int %arg1 : !tt.ptr<i64> -> i64
65+ // CHECK-NEXT: %19 = arith.subi %18, %17 : i64
66+
67+ // Calculate offset to peer base pointer.
68+ // CHECK-NEXT: %20 = arith.muli %arg2, %c8_i64 : i64
69+ // CHECK-NEXT: %21 = arith.addi %20, %c24_i64 : i64
70+
71+ // Load metadata->param_to_peers[argument_offset + peer_id].
72+ // CHECK-NEXT: %22 = tt.addptr %arg3, %21 : !tt.ptr<i64>, i64
73+ // CHECK-NEXT: %23 = tt.load %22 : !tt.ptr<i64>
74+
75+ // Compute metadata->param_to_peers[argument_offset + peer_id] + offset.
76+ // CHECK-NEXT: %24 = arith.addi %23, %19 : i64
77+ // CHECK-NEXT: %25 = tt.int_to_ptr %24 : i64 -> !tt.ptr<i64>
78+
79+ %arg_1_peer_ptr = triton_xla.get_peer_ptr %arg1 , %peer_id , %metadata ,
80+ { argument_index = 1 : i32 , world_size = 2 : i32 } :
81+ (!tt.ptr <i64 >, i64 , !tt.ptr <i64 >) -> !tt.ptr <i64 >
82+
83+ // Avoid optimizing away the get_peer_ptr calls by returning the bitwise OR of
84+ // the two peer pointers.
85+ //
86+ // CHECK-NEXT: %26 = tt.ptr_to_int %12 : !tt.ptr<i64> -> i64
87+ %int_arg0 = tt.ptr_to_int %arg_0_peer_ptr : !tt.ptr <i64 > -> i64
88+ // CHECK-NEXT: %27 = tt.ptr_to_int %25 : !tt.ptr<i64> -> i64
89+ %int_arg1 = tt.ptr_to_int %arg_1_peer_ptr : !tt.ptr <i64 > -> i64
90+
91+ // CHECK-NEXT: %28 = arith.ori %26, %27 : i64
92+ %result_int = arith.ori %int_arg0 , %int_arg1 : i64
93+ // CHECK-NEXT: %29 = tt.int_to_ptr %28 : i64 -> !tt.ptr<i64>
94+ %result_ptr = tt.int_to_ptr %result_int : i64 -> !tt.ptr <i64 >
95+ // CHECK-NEXT: tt.return %29 : !tt.ptr<i64>
96+ tt.return %result_ptr : !tt.ptr <i64 >
5097}
0 commit comments