Skip to content

Commit b4ff71b

Browse files
frgossenGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Add JAX-based pipeline parallelism test
PiperOrigin-RevId: 738800315
1 parent 6e2c9b0 commit b4ff71b

File tree

2 files changed

+336
-1
lines changed

2 files changed

+336
-1
lines changed

xla/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,6 +2674,7 @@ xla_test(
26742674
":hlo_test_base",
26752675
":literal_test_util",
26762676
":test_macros_header",
2677+
":test_utils",
26772678
":xla_internal_test_main",
26782679
"//xla:error_spec",
26792680
"//xla:literal",

xla/tests/collective_pipeline_parallelism_test.cc

Lines changed: 335 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "xla/tests/hlo_test_base.h"
3333
#include "xla/tests/literal_test_util.h"
3434
#include "xla/tests/test_macros.h"
35+
#include "xla/tests/test_utils.h"
3536
#include "tsl/platform/statusor.h"
3637

3738
namespace xla {
@@ -62,6 +63,7 @@ class CollectivePipelineParallelismTest
6263
xla_gpu_experimental_pipeline_parallelism_opt_level_);
6364
debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
6465
debug_options.set_xla_gpu_collective_permute_decomposer_threshold(0);
66+
debug_options.set_xla_gpu_autotune_level(0);
6567
config.set_debug_options(debug_options);
6668

6769
return config;
@@ -1239,7 +1241,7 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
12391241
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
12401242
channel_id=1
12411243
recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all),
1242-
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
1244+
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
12431245
channel_id=2
12441246
init = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[]))
12451247
tuple(i, send_ctx_, recv_ctx_)
@@ -1717,6 +1719,338 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
17171719
ErrorSpec{/*abs_error=*/1e-5, /*rel_error=*/1e-5}));
17181720
}
17191721

1722+
XLA_TEST_P(CollectivePipelineParallelismTest, JaxExampleWithDecomposedCycle) {
1723+
constexpr char kModuleStr[] = R"(
1724+
HloModule jit_entry_computation, entry_computation_layout={
1725+
(f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0})->
1726+
f32[4,5,4096,8192]{3,2,1,0}},
1727+
allow_spmd_sharding_propagation_to_parameters={false,false},
1728+
allow_spmd_sharding_propagation_to_output={true}, num_partitions=4
1729+
1730+
%_where.10 (Arg_0.11: pred[], Arg_1.12: s32[], Arg_2.13: s32[]) -> s32[] {
1731+
%Arg_0.11 = pred[] parameter(0)
1732+
%Arg_1.12 = s32[] parameter(1)
1733+
%Arg_2.13 = s32[] parameter(2)
1734+
ROOT %select.14 = s32[] select(%Arg_0.11, %Arg_1.12, %Arg_2.13)
1735+
}
1736+
1737+
%remainder.15 (Arg_0.16: s32[], Arg_1.17: s32[]) -> s32[] {
1738+
%Arg_0.16 = s32[] parameter(0)
1739+
%Arg_1.17 = s32[] parameter(1)
1740+
%constant.19 = s32[] constant(0)
1741+
%compare.20 = pred[] compare(%Arg_1.17, %constant.19), direction=EQ
1742+
%constant.18 = s32[] constant(1)
1743+
%call.21 = s32[] call(%compare.20, %constant.18, %Arg_1.17),
1744+
to_apply=%_where.10
1745+
%remainder.22 = s32[] remainder(%Arg_0.16, %call.21)
1746+
%compare.24 = pred[] compare(%remainder.22, %constant.19), direction=LT
1747+
%compare.25 = pred[] compare(%call.21, %constant.19), direction=LT
1748+
%compare.26 = pred[] compare(%compare.24, %compare.25), direction=NE
1749+
%compare.23 = pred[] compare(%remainder.22, %constant.19), direction=NE
1750+
%and.27 = pred[] and(%compare.26, %compare.23)
1751+
%add.28 = s32[] add(%remainder.22, %call.21)
1752+
ROOT %select.29 = s32[] select(%and.27, %add.28, %remainder.22)
1753+
}
1754+
1755+
%_pad.30 (Arg_0.31: f32[4,1,4096,8192], Arg_1.32: s32[]) -> f32[5,1,4096,8192] {
1756+
%Arg_0.31 = f32[4,1,4096,8192]{3,2,1,0} parameter(0)
1757+
%Arg_1.32 = s32[] parameter(1)
1758+
%convert.33 = f32[] convert(%Arg_1.32)
1759+
ROOT %pad.34 = f32[5,1,4096,8192]{3,2,1,0} pad(%Arg_0.31, %convert.33),
1760+
padding=1_0x0_0x0_0x0_0
1761+
}
1762+
1763+
%_where_0.35 (Arg_0.36: pred[], Arg_1.37: f32[4,1,4096,8192],
1764+
Arg_2.38: f32[4,1,4096,8192]) -> f32[4,1,4096,8192] {
1765+
%Arg_0.36 = pred[] parameter(0)
1766+
%broadcast.39 = pred[4,1,4096,8192]{3,2,1,0} broadcast(%Arg_0.36),
1767+
dimensions={}
1768+
%Arg_1.37 = f32[4,1,4096,8192]{3,2,1,0} parameter(1)
1769+
%Arg_2.38 = f32[4,1,4096,8192]{3,2,1,0} parameter(2)
1770+
ROOT %select.40 = f32[4,1,4096,8192]{3,2,1,0} select(%broadcast.39, %Arg_1.37,
1771+
%Arg_2.38)
1772+
}
1773+
1774+
%_where_1.41 (Arg_0.42: pred[4,1,4096,8192], Arg_1.43: f32[4,1,4096,8192],
1775+
Arg_2.44: f32[4,1,4096,8192]) -> f32[4,1,4096,8192] {
1776+
%Arg_0.42 = pred[4,1,4096,8192]{3,2,1,0} parameter(0)
1777+
%Arg_1.43 = f32[4,1,4096,8192]{3,2,1,0} parameter(1)
1778+
%Arg_2.44 = f32[4,1,4096,8192]{3,2,1,0} parameter(2)
1779+
ROOT %select.45 = f32[4,1,4096,8192]{3,2,1,0} select(%Arg_0.42, %Arg_1.43,
1780+
%Arg_2.44)
1781+
}
1782+
1783+
%_where.46 (Arg_0.47: pred[], Arg_1.48: s32[], Arg_2.49: s32[]) -> s32[] {
1784+
%Arg_0.47 = pred[] parameter(0)
1785+
%Arg_1.48 = s32[] parameter(1)
1786+
%Arg_2.49 = s32[] parameter(2)
1787+
ROOT %select.50 = s32[] select(%Arg_0.47, %Arg_1.48, %Arg_2.49)
1788+
}
1789+
1790+
%remainder.51 (Arg_0.52: s32[], Arg_1.53: s32[]) -> s32[] {
1791+
%Arg_0.52 = s32[] parameter(0)
1792+
%Arg_1.53 = s32[] parameter(1)
1793+
%constant.55 = s32[] constant(0)
1794+
%compare.56 = pred[] compare(%Arg_1.53, %constant.55), direction=EQ
1795+
%constant.54 = s32[] constant(1)
1796+
%call.57 = s32[] call(%compare.56, %constant.54, %Arg_1.53),
1797+
to_apply=%_where.46
1798+
%remainder.58 = s32[] remainder(%Arg_0.52, %call.57)
1799+
%compare.60 = pred[] compare(%remainder.58, %constant.55), direction=LT
1800+
%compare.61 = pred[] compare(%call.57, %constant.55), direction=LT
1801+
%compare.62 = pred[] compare(%compare.60, %compare.61), direction=NE
1802+
%compare.59 = pred[] compare(%remainder.58, %constant.55), direction=NE
1803+
%and.63 = pred[] and(%compare.62, %compare.59)
1804+
%add.64 = s32[] add(%remainder.58, %call.57)
1805+
ROOT %select.65 = s32[] select(%and.63, %add.64, %remainder.58)
1806+
}
1807+
1808+
%_pad_2.66 (Arg_0.67: f32[4,1,4096,8192], Arg_1.68: s32[])
1809+
-> f32[7,1,4096,8192] {
1810+
%Arg_0.67 = f32[4,1,4096,8192]{3,2,1,0} parameter(0)
1811+
%Arg_1.68 = s32[] parameter(1)
1812+
%convert.69 = f32[] convert(%Arg_1.68)
1813+
ROOT %pad.70 = f32[7,1,4096,8192]{3,2,1,0} pad(%Arg_0.67, %convert.69),
1814+
padding=0_3x0_0x0_0x0_0
1815+
}
1816+
1817+
%_where.71 (Arg_0.72: pred[], Arg_1.73: s32[], Arg_2.74: s32[]) -> s32[] {
1818+
%Arg_0.72 = pred[] parameter(0)
1819+
%Arg_1.73 = s32[] parameter(1)
1820+
%Arg_2.74 = s32[] parameter(2)
1821+
ROOT %select.75 = s32[] select(%Arg_0.72, %Arg_1.73, %Arg_2.74)
1822+
}
1823+
1824+
%remainder.76 (Arg_0.77: s32[], Arg_1.78: s32[]) -> s32[] {
1825+
%Arg_0.77 = s32[] parameter(0)
1826+
%Arg_1.78 = s32[] parameter(1)
1827+
%constant.80 = s32[] constant(0)
1828+
%compare.81 = pred[] compare(%Arg_1.78, %constant.80), direction=EQ
1829+
%constant.79 = s32[] constant(1)
1830+
%call.82 = s32[] call(%compare.81, %constant.79, %Arg_1.78),
1831+
to_apply=%_where.71
1832+
%remainder.83 = s32[] remainder(%Arg_0.77, %call.82)
1833+
%compare.85 = pred[] compare(%remainder.83, %constant.80), direction=LT
1834+
%compare.86 = pred[] compare(%call.82, %constant.80), direction=LT
1835+
%compare.87 = pred[] compare(%compare.85, %compare.86), direction=NE
1836+
%compare.84 = pred[] compare(%remainder.83, %constant.80), direction=NE
1837+
%and.88 = pred[] and(%compare.87, %compare.84)
1838+
%add.89 = s32[] add(%remainder.83, %call.82)
1839+
ROOT %select.90 = s32[] select(%and.88, %add.89, %remainder.83)
1840+
}
1841+
1842+
%None.91 (Arg_0.92: f32[4,4096,4096], Arg_1.93: f32[4,5,4096,8192],
1843+
Arg_2.94: f32[4,5,4096,8192], Arg_3.95: f32[4,1,4096,8192],
1844+
Arg_4.96: f32[4,1,4096,8192], Arg_5.97: s32[])
1845+
-> (f32[4,4096,4096], f32[4,5,4096,8192], f32[4,5,4096,8192],
1846+
f32[4,1,4096,8192], f32[4,1,4096,8192]) {
1847+
%Arg_0.92 = f32[4,4096,4096]{2,1,0} parameter(0)
1848+
%Arg_1.93 = f32[4,5,4096,8192]{3,2,1,0} parameter(1)
1849+
%Arg_2.94 = f32[4,5,4096,8192]{3,2,1,0} parameter(2)
1850+
%iota.113 = s32[4]{0} iota(), iota_dimension=0
1851+
%broadcast.114 = s32[4,1,4096,8192]{3,2,1,0} broadcast(%iota.113),
1852+
dimensions={0}
1853+
%constant.98 = s32[] constant(0)
1854+
%broadcast.99 = s32[4,1,4096,8192]{3,2,1,0} broadcast(%constant.98),
1855+
dimensions={}
1856+
%compare.115 = pred[4,1,4096,8192]{3,2,1,0} compare(%broadcast.114,
1857+
%broadcast.99), direction=EQ
1858+
%Arg_5.97 = s32[] parameter(5)
1859+
%constant.102 = s32[] constant(5)
1860+
%compare.111 = pred[] compare(%Arg_5.97, %constant.102), direction=LT
1861+
%constant.103 = s32[] constant(0)
1862+
%call.104 = s32[] call(%Arg_5.97, %constant.102), to_apply=%remainder.15
1863+
%compare.105 = pred[] compare(%call.104, %constant.103), direction=LT
1864+
%add.106 = s32[] add(%call.104, %constant.102)
1865+
%select.107 = s32[] select(%compare.105, %add.106, %call.104)
1866+
%dynamic-slice.108 = f32[4,1,4096,8192]{3,2,1,0} dynamic-slice(%Arg_1.93,
1867+
%constant.103, %select.107, %constant.103, %constant.103),
1868+
dynamic_slice_sizes={4,1,4096,8192}
1869+
%Arg_4.96 = f32[4,1,4096,8192]{3,2,1,0} parameter(4)
1870+
%call.112 = f32[4,1,4096,8192]{3,2,1,0} call(%compare.111, %dynamic-slice.108,
1871+
%Arg_4.96), to_apply=%_where_0.35
1872+
%Arg_3.95 = f32[4,1,4096,8192]{3,2,1,0} parameter(3)
1873+
%call.109 = f32[5,1,4096,8192]{3,2,1,0} call(%Arg_3.95, %constant.103),
1874+
to_apply=%_pad.30
1875+
%slice.110 = f32[4,1,4096,8192]{3,2,1,0} slice(%call.109),
1876+
slice={[0:4], [0:1], [0:4096], [0:8192]}
1877+
%call.116 = f32[4,1,4096,8192]{3,2,1,0} call(%compare.115, %call.112,
1878+
%slice.110), to_apply=%_where_1.41
1879+
%reshape.117 = f32[4,4096,8192]{2,1,0} reshape(%call.116)
1880+
%dot.118 = f32[4,4096,8192]{2,1,0} dot(%Arg_0.92, %reshape.117),
1881+
lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0},
1882+
rhs_contracting_dims={1}
1883+
%reshape.150 = f32[4,1,4096,8192]{3,2,1,0} reshape(%dot.118)
1884+
%constant.100 = s32[] constant(2)
1885+
%add.159 = s32[] add(%Arg_5.97, %constant.100)
1886+
%call.160 = s32[] call(%add.159, %constant.102), to_apply=%remainder.76
1887+
%compare.161 = pred[] compare(%call.160, %constant.103), direction=LT
1888+
%add.162 = s32[] add(%call.160, %constant.102)
1889+
%select.163 = s32[] select(%compare.161, %add.162, %call.160)
1890+
%dynamic-update-slice.164 = f32[4,5,4096,8192]{3,2,1,0}
1891+
dynamic-update-slice(%Arg_2.94, %reshape.150, %constant.103, %select.163,
1892+
%constant.103, /*index=5*/%constant.103)
1893+
%constant.101 = s32[] constant(1)
1894+
%add.151 = s32[] add(%Arg_5.97, %constant.101)
1895+
%call.152 = s32[] call(%add.151, %constant.102), to_apply=%remainder.51
1896+
%compare.153 = pred[] compare(%call.152, %constant.103), direction=LT
1897+
%add.154 = s32[] add(%call.152, %constant.102)
1898+
%select.155 = s32[] select(%compare.153, %add.154, %call.152)
1899+
%dynamic-slice.156 = f32[4,1,4096,8192]{3,2,1,0} dynamic-slice(%Arg_2.94,
1900+
%constant.103, %select.155, %constant.103, %constant.103),
1901+
dynamic_slice_sizes={4,1,4096,8192}
1902+
%call.157 = f32[7,1,4096,8192]{3,2,1,0} call(%dynamic-slice.156,
1903+
%constant.103), to_apply=%_pad_2.66
1904+
%slice.158 = f32[4,1,4096,8192]{3,2,1,0} slice(%call.157),
1905+
slice={[3:7], [0:1], [0:4096], [0:8192]}
1906+
ROOT %tuple.165 = (f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1907+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1908+
f32[4,1,4096,8192]{3,2,1,0}) tuple(%Arg_0.92, %Arg_1.93,
1909+
%dynamic-update-slice.164, %reshape.150, %slice.158)
1910+
}
1911+
1912+
%region_0.166 (arg_tuple.167: (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1913+
f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1914+
s32[13])) -> (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1915+
f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1916+
s32[13]) {
1917+
%arg_tuple.167 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1918+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1919+
/*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) parameter(0)
1920+
%get-tuple-element.168 = s32[] get-tuple-element(%arg_tuple.167), index=0
1921+
%constant.175 = s32[] constant(1)
1922+
%add.184 = s32[] add(%get-tuple-element.168, %constant.175)
1923+
%get-tuple-element.169 = f32[4,4096,4096]{2,1,0}
1924+
get-tuple-element(%arg_tuple.167), index=1
1925+
%get-tuple-element.170 = f32[4,5,4096,8192]{3,2,1,0}
1926+
get-tuple-element(%arg_tuple.167), index=2
1927+
%get-tuple-element.171 = f32[4,5,4096,8192]{3,2,1,0}
1928+
get-tuple-element(%arg_tuple.167), index=3
1929+
%get-tuple-element.172 = f32[4,1,4096,8192]{3,2,1,0}
1930+
get-tuple-element(%arg_tuple.167), index=4
1931+
%get-tuple-element.173 = f32[4,1,4096,8192]{3,2,1,0}
1932+
get-tuple-element(%arg_tuple.167), index=5
1933+
%get-tuple-element.174 = s32[13]{0} get-tuple-element(%arg_tuple.167), index=6
1934+
%dynamic-slice.176 = s32[1]{0} dynamic-slice(%get-tuple-element.174,
1935+
%get-tuple-element.168), dynamic_slice_sizes={1}
1936+
%reshape.177 = s32[] reshape(%dynamic-slice.176)
1937+
%call.178 = (f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1938+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1939+
f32[4,1,4096,8192]{3,2,1,0}) call(%get-tuple-element.169,
1940+
%get-tuple-element.170, %get-tuple-element.171, %get-tuple-element.172,
1941+
%get-tuple-element.173, /*index=5*/%reshape.177), to_apply=%None.91
1942+
%get-tuple-element.179 = f32[4,4096,4096]{2,1,0} get-tuple-element(%call.178),
1943+
index=0
1944+
%get-tuple-element.180 = f32[4,5,4096,8192]{3,2,1,0}
1945+
get-tuple-element(%call.178), index=1
1946+
%get-tuple-element.181 = f32[4,5,4096,8192]{3,2,1,0}
1947+
get-tuple-element(%call.178), index=2
1948+
%get-tuple-element.182 = f32[4,1,4096,8192]{3,2,1,0}
1949+
get-tuple-element(%call.178), index=3
1950+
%get-tuple-element.183 = f32[4,1,4096,8192]{3,2,1,0}
1951+
get-tuple-element(%call.178), index=4
1952+
ROOT %tuple.185 = (s32[], f32[4,4096,4096]{2,1,0},
1953+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1954+
f32[4,1,4096,8192]{3,2,1,0}, /*index=5*/f32[4,1,4096,8192]{3,2,1,0},
1955+
s32[13]{0}) tuple(%add.184, %get-tuple-element.179,
1956+
%get-tuple-element.180, %get-tuple-element.181, %get-tuple-element.182,
1957+
/*index=5*/%get-tuple-element.183, %get-tuple-element.174)
1958+
}
1959+
1960+
%region_1.186 (arg_tuple.187: (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1961+
f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1962+
s32[13])) -> pred[] {
1963+
%arg_tuple.187 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1964+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1965+
/*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) parameter(0)
1966+
%get-tuple-element.189 = f32[4,4096,4096]{2,1,0}
1967+
get-tuple-element(%arg_tuple.187), index=1
1968+
%get-tuple-element.190 = f32[4,5,4096,8192]{3,2,1,0}
1969+
get-tuple-element(%arg_tuple.187), index=2
1970+
%get-tuple-element.191 = f32[4,5,4096,8192]{3,2,1,0}
1971+
get-tuple-element(%arg_tuple.187), index=3
1972+
%get-tuple-element.192 = f32[4,1,4096,8192]{3,2,1,0}
1973+
get-tuple-element(%arg_tuple.187), index=4
1974+
%get-tuple-element.193 = f32[4,1,4096,8192]{3,2,1,0}
1975+
get-tuple-element(%arg_tuple.187), index=5
1976+
%get-tuple-element.194 = s32[13]{0} get-tuple-element(%arg_tuple.187), index=6
1977+
%get-tuple-element.188 = s32[] get-tuple-element(%arg_tuple.187), index=0
1978+
%constant.195 = s32[] constant(13)
1979+
ROOT %compare.196 = pred[] compare(%get-tuple-element.188, %constant.195),
1980+
direction=LT
1981+
}
1982+
1983+
ENTRY %main.204 (Arg_0.1: f32[4,4096,4096], Arg_1.2: f32[4,5,4096,8192])
1984+
-> f32[4,5,4096,8192] {
1985+
%constant.3 = s32[] constant(0)
1986+
%Arg_0.1 = f32[4,4096,4096]{2,1,0} parameter(0),
1987+
sharding={devices=[4,1,1]<=[4]}
1988+
%Arg_1.2 = f32[4,5,4096,8192]{3,2,1,0} parameter(1),
1989+
sharding={devices=[4,1,1,1]<=[4]}
1990+
%constant.4 = f32[] constant(0)
1991+
%broadcast.5 = f32[4,5,4096,8192]{3,2,1,0} broadcast(%constant.4),
1992+
dimensions={}
1993+
%constant.6 = f32[] constant(0)
1994+
%broadcast.7 = f32[4,1,4096,8192]{3,2,1,0} broadcast(%constant.6),
1995+
dimensions={}
1996+
%iota.8 = s32[13]{0} iota(), iota_dimension=0
1997+
%tuple.9 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1998+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1999+
/*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) tuple(%constant.3,
2000+
%Arg_0.1, %Arg_1.2, %broadcast.5, %broadcast.7, /*index=5*/%broadcast.7,
2001+
%iota.8)
2002+
%while.197 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
2003+
f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
2004+
/*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) while(%tuple.9),
2005+
condition=%region_1.186, body=%region_0.166
2006+
%get-tuple-element.198 = s32[] get-tuple-element(%while.197), index=0
2007+
%get-tuple-element.199 = f32[4,4096,4096]{2,1,0}
2008+
get-tuple-element(%while.197), index=1
2009+
%get-tuple-element.200 = f32[4,5,4096,8192]{3,2,1,0}
2010+
get-tuple-element(%while.197), index=2
2011+
ROOT %get-tuple-element.201 = f32[4,5,4096,8192]{3,2,1,0}
2012+
get-tuple-element(%while.197), index=3
2013+
%get-tuple-element.202 = f32[4,1,4096,8192]{3,2,1,0}
2014+
get-tuple-element(%while.197), index=4
2015+
%get-tuple-element.203 = f32[4,1,4096,8192]{3,2,1,0}
2016+
get-tuple-element(%while.197), index=5
2017+
}
2018+
)";
2019+
2020+
const int64_t kNumReplicas = 1;
2021+
const int64_t kNumPartitions = 4;
2022+
if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
2023+
GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions
2024+
<< " devices (" << test_runner().device_count()
2025+
<< " available)";
2026+
}
2027+
2028+
HloModuleConfig config = GetModuleConfigForTest(
2029+
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
2030+
TF_ASSERT_OK_AND_ASSIGN(auto module,
2031+
ParseAndReturnVerifiedModule(kModuleStr, config));
2032+
2033+
// Create device assignment running across partitions.
2034+
DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas,
2035+
/*computation_count=*/kNumPartitions);
2036+
for (int64_t i = 0; i < kNumPartitions; ++i) {
2037+
device_assignment(0, i) = i;
2038+
}
2039+
2040+
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> fake_args,
2041+
MakeFakeArguments(module.get()));
2042+
std::vector<Literal *> args;
2043+
for (auto &arg : fake_args) {
2044+
args.push_back(&arg);
2045+
}
2046+
TF_ASSERT_OK_AND_ASSIGN(
2047+
std::vector<Literal> results,
2048+
ExecuteReplicated(std::move(module), args,
2049+
/*num_replicas=*/kNumPartitions, &device_assignment,
2050+
/*run_hlo_passes=*/true, /*use_threads=*/true));
2051+
ASSERT_EQ(results.size(), kNumPartitions);
2052+
}
2053+
17202054
INSTANTIATE_TEST_SUITE_P(
17212055
CollectivePipelineParallelismTestWithAndWithoutOpts,
17222056
CollectivePipelineParallelismTest,

0 commit comments

Comments
 (0)