@@ -32,6 +32,7 @@ limitations under the License.
3232#include " xla/tests/hlo_test_base.h"
3333#include " xla/tests/literal_test_util.h"
3434#include " xla/tests/test_macros.h"
35+ #include " xla/tests/test_utils.h"
3536#include " tsl/platform/statusor.h"
3637
3738namespace xla {
@@ -62,6 +63,7 @@ class CollectivePipelineParallelismTest
6263 xla_gpu_experimental_pipeline_parallelism_opt_level_);
6364 debug_options.set_xla_gpu_enable_latency_hiding_scheduler (true );
6465 debug_options.set_xla_gpu_collective_permute_decomposer_threshold (0 );
66+ debug_options.set_xla_gpu_autotune_level (0 );
6567 config.set_debug_options (debug_options);
6668
6769 return config;
@@ -1239,7 +1241,7 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
12391241 frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
12401242 channel_id=1
12411243 recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all),
1242- frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
1244+ frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}},
12431245 channel_id=2
12441246 init = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[]))
12451247 tuple(i, send_ctx_, recv_ctx_)
@@ -1717,6 +1719,338 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
17171719 ErrorSpec{/* abs_error=*/ 1e-5 , /* rel_error=*/ 1e-5 }));
17181720}
17191721
1722+ XLA_TEST_P (CollectivePipelineParallelismTest, JaxExampleWithDecomposedCycle) {
1723+ constexpr char kModuleStr [] = R"(
1724+ HloModule jit_entry_computation, entry_computation_layout={
1725+ (f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0})->
1726+ f32[4,5,4096,8192]{3,2,1,0}},
1727+ allow_spmd_sharding_propagation_to_parameters={false,false},
1728+ allow_spmd_sharding_propagation_to_output={true}, num_partitions=4
1729+
1730+ %_where.10 (Arg_0.11: pred[], Arg_1.12: s32[], Arg_2.13: s32[]) -> s32[] {
1731+ %Arg_0.11 = pred[] parameter(0)
1732+ %Arg_1.12 = s32[] parameter(1)
1733+ %Arg_2.13 = s32[] parameter(2)
1734+ ROOT %select.14 = s32[] select(%Arg_0.11, %Arg_1.12, %Arg_2.13)
1735+ }
1736+
1737+ %remainder.15 (Arg_0.16: s32[], Arg_1.17: s32[]) -> s32[] {
1738+ %Arg_0.16 = s32[] parameter(0)
1739+ %Arg_1.17 = s32[] parameter(1)
1740+ %constant.19 = s32[] constant(0)
1741+ %compare.20 = pred[] compare(%Arg_1.17, %constant.19), direction=EQ
1742+ %constant.18 = s32[] constant(1)
1743+ %call.21 = s32[] call(%compare.20, %constant.18, %Arg_1.17),
1744+ to_apply=%_where.10
1745+ %remainder.22 = s32[] remainder(%Arg_0.16, %call.21)
1746+ %compare.24 = pred[] compare(%remainder.22, %constant.19), direction=LT
1747+ %compare.25 = pred[] compare(%call.21, %constant.19), direction=LT
1748+ %compare.26 = pred[] compare(%compare.24, %compare.25), direction=NE
1749+ %compare.23 = pred[] compare(%remainder.22, %constant.19), direction=NE
1750+ %and.27 = pred[] and(%compare.26, %compare.23)
1751+ %add.28 = s32[] add(%remainder.22, %call.21)
1752+ ROOT %select.29 = s32[] select(%and.27, %add.28, %remainder.22)
1753+ }
1754+
1755+ %_pad.30 (Arg_0.31: f32[4,1,4096,8192], Arg_1.32: s32[]) -> f32[5,1,4096,8192] {
1756+ %Arg_0.31 = f32[4,1,4096,8192]{3,2,1,0} parameter(0)
1757+ %Arg_1.32 = s32[] parameter(1)
1758+ %convert.33 = f32[] convert(%Arg_1.32)
1759+ ROOT %pad.34 = f32[5,1,4096,8192]{3,2,1,0} pad(%Arg_0.31, %convert.33),
1760+ padding=1_0x0_0x0_0x0_0
1761+ }
1762+
1763+ %_where_0.35 (Arg_0.36: pred[], Arg_1.37: f32[4,1,4096,8192],
1764+ Arg_2.38: f32[4,1,4096,8192]) -> f32[4,1,4096,8192] {
1765+ %Arg_0.36 = pred[] parameter(0)
1766+ %broadcast.39 = pred[4,1,4096,8192]{3,2,1,0} broadcast(%Arg_0.36),
1767+ dimensions={}
1768+ %Arg_1.37 = f32[4,1,4096,8192]{3,2,1,0} parameter(1)
1769+ %Arg_2.38 = f32[4,1,4096,8192]{3,2,1,0} parameter(2)
1770+ ROOT %select.40 = f32[4,1,4096,8192]{3,2,1,0} select(%broadcast.39, %Arg_1.37,
1771+ %Arg_2.38)
1772+ }
1773+
1774+ %_where_1.41 (Arg_0.42: pred[4,1,4096,8192], Arg_1.43: f32[4,1,4096,8192],
1775+ Arg_2.44: f32[4,1,4096,8192]) -> f32[4,1,4096,8192] {
1776+ %Arg_0.42 = pred[4,1,4096,8192]{3,2,1,0} parameter(0)
1777+ %Arg_1.43 = f32[4,1,4096,8192]{3,2,1,0} parameter(1)
1778+ %Arg_2.44 = f32[4,1,4096,8192]{3,2,1,0} parameter(2)
1779+ ROOT %select.45 = f32[4,1,4096,8192]{3,2,1,0} select(%Arg_0.42, %Arg_1.43,
1780+ %Arg_2.44)
1781+ }
1782+
1783+ %_where.46 (Arg_0.47: pred[], Arg_1.48: s32[], Arg_2.49: s32[]) -> s32[] {
1784+ %Arg_0.47 = pred[] parameter(0)
1785+ %Arg_1.48 = s32[] parameter(1)
1786+ %Arg_2.49 = s32[] parameter(2)
1787+ ROOT %select.50 = s32[] select(%Arg_0.47, %Arg_1.48, %Arg_2.49)
1788+ }
1789+
1790+ %remainder.51 (Arg_0.52: s32[], Arg_1.53: s32[]) -> s32[] {
1791+ %Arg_0.52 = s32[] parameter(0)
1792+ %Arg_1.53 = s32[] parameter(1)
1793+ %constant.55 = s32[] constant(0)
1794+ %compare.56 = pred[] compare(%Arg_1.53, %constant.55), direction=EQ
1795+ %constant.54 = s32[] constant(1)
1796+ %call.57 = s32[] call(%compare.56, %constant.54, %Arg_1.53),
1797+ to_apply=%_where.46
1798+ %remainder.58 = s32[] remainder(%Arg_0.52, %call.57)
1799+ %compare.60 = pred[] compare(%remainder.58, %constant.55), direction=LT
1800+ %compare.61 = pred[] compare(%call.57, %constant.55), direction=LT
1801+ %compare.62 = pred[] compare(%compare.60, %compare.61), direction=NE
1802+ %compare.59 = pred[] compare(%remainder.58, %constant.55), direction=NE
1803+ %and.63 = pred[] and(%compare.62, %compare.59)
1804+ %add.64 = s32[] add(%remainder.58, %call.57)
1805+ ROOT %select.65 = s32[] select(%and.63, %add.64, %remainder.58)
1806+ }
1807+
1808+ %_pad_2.66 (Arg_0.67: f32[4,1,4096,8192], Arg_1.68: s32[])
1809+ -> f32[7,1,4096,8192] {
1810+ %Arg_0.67 = f32[4,1,4096,8192]{3,2,1,0} parameter(0)
1811+ %Arg_1.68 = s32[] parameter(1)
1812+ %convert.69 = f32[] convert(%Arg_1.68)
1813+ ROOT %pad.70 = f32[7,1,4096,8192]{3,2,1,0} pad(%Arg_0.67, %convert.69),
1814+ padding=0_3x0_0x0_0x0_0
1815+ }
1816+
1817+ %_where.71 (Arg_0.72: pred[], Arg_1.73: s32[], Arg_2.74: s32[]) -> s32[] {
1818+ %Arg_0.72 = pred[] parameter(0)
1819+ %Arg_1.73 = s32[] parameter(1)
1820+ %Arg_2.74 = s32[] parameter(2)
1821+ ROOT %select.75 = s32[] select(%Arg_0.72, %Arg_1.73, %Arg_2.74)
1822+ }
1823+
1824+ %remainder.76 (Arg_0.77: s32[], Arg_1.78: s32[]) -> s32[] {
1825+ %Arg_0.77 = s32[] parameter(0)
1826+ %Arg_1.78 = s32[] parameter(1)
1827+ %constant.80 = s32[] constant(0)
1828+ %compare.81 = pred[] compare(%Arg_1.78, %constant.80), direction=EQ
1829+ %constant.79 = s32[] constant(1)
1830+ %call.82 = s32[] call(%compare.81, %constant.79, %Arg_1.78),
1831+ to_apply=%_where.71
1832+ %remainder.83 = s32[] remainder(%Arg_0.77, %call.82)
1833+ %compare.85 = pred[] compare(%remainder.83, %constant.80), direction=LT
1834+ %compare.86 = pred[] compare(%call.82, %constant.80), direction=LT
1835+ %compare.87 = pred[] compare(%compare.85, %compare.86), direction=NE
1836+ %compare.84 = pred[] compare(%remainder.83, %constant.80), direction=NE
1837+ %and.88 = pred[] and(%compare.87, %compare.84)
1838+ %add.89 = s32[] add(%remainder.83, %call.82)
1839+ ROOT %select.90 = s32[] select(%and.88, %add.89, %remainder.83)
1840+ }
1841+
1842+ %None.91 (Arg_0.92: f32[4,4096,4096], Arg_1.93: f32[4,5,4096,8192],
1843+ Arg_2.94: f32[4,5,4096,8192], Arg_3.95: f32[4,1,4096,8192],
1844+ Arg_4.96: f32[4,1,4096,8192], Arg_5.97: s32[])
1845+ -> (f32[4,4096,4096], f32[4,5,4096,8192], f32[4,5,4096,8192],
1846+ f32[4,1,4096,8192], f32[4,1,4096,8192]) {
1847+ %Arg_0.92 = f32[4,4096,4096]{2,1,0} parameter(0)
1848+ %Arg_1.93 = f32[4,5,4096,8192]{3,2,1,0} parameter(1)
1849+ %Arg_2.94 = f32[4,5,4096,8192]{3,2,1,0} parameter(2)
1850+ %iota.113 = s32[4]{0} iota(), iota_dimension=0
1851+ %broadcast.114 = s32[4,1,4096,8192]{3,2,1,0} broadcast(%iota.113),
1852+ dimensions={0}
1853+ %constant.98 = s32[] constant(0)
1854+ %broadcast.99 = s32[4,1,4096,8192]{3,2,1,0} broadcast(%constant.98),
1855+ dimensions={}
1856+ %compare.115 = pred[4,1,4096,8192]{3,2,1,0} compare(%broadcast.114,
1857+ %broadcast.99), direction=EQ
1858+ %Arg_5.97 = s32[] parameter(5)
1859+ %constant.102 = s32[] constant(5)
1860+ %compare.111 = pred[] compare(%Arg_5.97, %constant.102), direction=LT
1861+ %constant.103 = s32[] constant(0)
1862+ %call.104 = s32[] call(%Arg_5.97, %constant.102), to_apply=%remainder.15
1863+ %compare.105 = pred[] compare(%call.104, %constant.103), direction=LT
1864+ %add.106 = s32[] add(%call.104, %constant.102)
1865+ %select.107 = s32[] select(%compare.105, %add.106, %call.104)
1866+ %dynamic-slice.108 = f32[4,1,4096,8192]{3,2,1,0} dynamic-slice(%Arg_1.93,
1867+ %constant.103, %select.107, %constant.103, %constant.103),
1868+ dynamic_slice_sizes={4,1,4096,8192}
1869+ %Arg_4.96 = f32[4,1,4096,8192]{3,2,1,0} parameter(4)
1870+ %call.112 = f32[4,1,4096,8192]{3,2,1,0} call(%compare.111, %dynamic-slice.108,
1871+ %Arg_4.96), to_apply=%_where_0.35
1872+ %Arg_3.95 = f32[4,1,4096,8192]{3,2,1,0} parameter(3)
1873+ %call.109 = f32[5,1,4096,8192]{3,2,1,0} call(%Arg_3.95, %constant.103),
1874+ to_apply=%_pad.30
1875+ %slice.110 = f32[4,1,4096,8192]{3,2,1,0} slice(%call.109),
1876+ slice={[0:4], [0:1], [0:4096], [0:8192]}
1877+ %call.116 = f32[4,1,4096,8192]{3,2,1,0} call(%compare.115, %call.112,
1878+ %slice.110), to_apply=%_where_1.41
1879+ %reshape.117 = f32[4,4096,8192]{2,1,0} reshape(%call.116)
1880+ %dot.118 = f32[4,4096,8192]{2,1,0} dot(%Arg_0.92, %reshape.117),
1881+ lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0},
1882+ rhs_contracting_dims={1}
1883+ %reshape.150 = f32[4,1,4096,8192]{3,2,1,0} reshape(%dot.118)
1884+ %constant.100 = s32[] constant(2)
1885+ %add.159 = s32[] add(%Arg_5.97, %constant.100)
1886+ %call.160 = s32[] call(%add.159, %constant.102), to_apply=%remainder.76
1887+ %compare.161 = pred[] compare(%call.160, %constant.103), direction=LT
1888+ %add.162 = s32[] add(%call.160, %constant.102)
1889+ %select.163 = s32[] select(%compare.161, %add.162, %call.160)
1890+ %dynamic-update-slice.164 = f32[4,5,4096,8192]{3,2,1,0}
1891+ dynamic-update-slice(%Arg_2.94, %reshape.150, %constant.103, %select.163,
1892+ %constant.103, /*index=5*/%constant.103)
1893+ %constant.101 = s32[] constant(1)
1894+ %add.151 = s32[] add(%Arg_5.97, %constant.101)
1895+ %call.152 = s32[] call(%add.151, %constant.102), to_apply=%remainder.51
1896+ %compare.153 = pred[] compare(%call.152, %constant.103), direction=LT
1897+ %add.154 = s32[] add(%call.152, %constant.102)
1898+ %select.155 = s32[] select(%compare.153, %add.154, %call.152)
1899+ %dynamic-slice.156 = f32[4,1,4096,8192]{3,2,1,0} dynamic-slice(%Arg_2.94,
1900+ %constant.103, %select.155, %constant.103, %constant.103),
1901+ dynamic_slice_sizes={4,1,4096,8192}
1902+ %call.157 = f32[7,1,4096,8192]{3,2,1,0} call(%dynamic-slice.156,
1903+ %constant.103), to_apply=%_pad_2.66
1904+ %slice.158 = f32[4,1,4096,8192]{3,2,1,0} slice(%call.157),
1905+ slice={[3:7], [0:1], [0:4096], [0:8192]}
1906+ ROOT %tuple.165 = (f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1907+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1908+ f32[4,1,4096,8192]{3,2,1,0}) tuple(%Arg_0.92, %Arg_1.93,
1909+ %dynamic-update-slice.164, %reshape.150, %slice.158)
1910+ }
1911+
1912+ %region_0.166 (arg_tuple.167: (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1913+ f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1914+ s32[13])) -> (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1915+ f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1916+ s32[13]) {
1917+ %arg_tuple.167 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1918+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1919+ /*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) parameter(0)
1920+ %get-tuple-element.168 = s32[] get-tuple-element(%arg_tuple.167), index=0
1921+ %constant.175 = s32[] constant(1)
1922+ %add.184 = s32[] add(%get-tuple-element.168, %constant.175)
1923+ %get-tuple-element.169 = f32[4,4096,4096]{2,1,0}
1924+ get-tuple-element(%arg_tuple.167), index=1
1925+ %get-tuple-element.170 = f32[4,5,4096,8192]{3,2,1,0}
1926+ get-tuple-element(%arg_tuple.167), index=2
1927+ %get-tuple-element.171 = f32[4,5,4096,8192]{3,2,1,0}
1928+ get-tuple-element(%arg_tuple.167), index=3
1929+ %get-tuple-element.172 = f32[4,1,4096,8192]{3,2,1,0}
1930+ get-tuple-element(%arg_tuple.167), index=4
1931+ %get-tuple-element.173 = f32[4,1,4096,8192]{3,2,1,0}
1932+ get-tuple-element(%arg_tuple.167), index=5
1933+ %get-tuple-element.174 = s32[13]{0} get-tuple-element(%arg_tuple.167), index=6
1934+ %dynamic-slice.176 = s32[1]{0} dynamic-slice(%get-tuple-element.174,
1935+ %get-tuple-element.168), dynamic_slice_sizes={1}
1936+ %reshape.177 = s32[] reshape(%dynamic-slice.176)
1937+ %call.178 = (f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1938+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1939+ f32[4,1,4096,8192]{3,2,1,0}) call(%get-tuple-element.169,
1940+ %get-tuple-element.170, %get-tuple-element.171, %get-tuple-element.172,
1941+ %get-tuple-element.173, /*index=5*/%reshape.177), to_apply=%None.91
1942+ %get-tuple-element.179 = f32[4,4096,4096]{2,1,0} get-tuple-element(%call.178),
1943+ index=0
1944+ %get-tuple-element.180 = f32[4,5,4096,8192]{3,2,1,0}
1945+ get-tuple-element(%call.178), index=1
1946+ %get-tuple-element.181 = f32[4,5,4096,8192]{3,2,1,0}
1947+ get-tuple-element(%call.178), index=2
1948+ %get-tuple-element.182 = f32[4,1,4096,8192]{3,2,1,0}
1949+ get-tuple-element(%call.178), index=3
1950+ %get-tuple-element.183 = f32[4,1,4096,8192]{3,2,1,0}
1951+ get-tuple-element(%call.178), index=4
1952+ ROOT %tuple.185 = (s32[], f32[4,4096,4096]{2,1,0},
1953+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1954+ f32[4,1,4096,8192]{3,2,1,0}, /*index=5*/f32[4,1,4096,8192]{3,2,1,0},
1955+ s32[13]{0}) tuple(%add.184, %get-tuple-element.179,
1956+ %get-tuple-element.180, %get-tuple-element.181, %get-tuple-element.182,
1957+ /*index=5*/%get-tuple-element.183, %get-tuple-element.174)
1958+ }
1959+
1960+ %region_1.186 (arg_tuple.187: (s32[], f32[4,4096,4096], f32[4,5,4096,8192],
1961+ f32[4,5,4096,8192], f32[4,1,4096,8192], /*index=5*/f32[4,1,4096,8192],
1962+ s32[13])) -> pred[] {
1963+ %arg_tuple.187 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1964+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1965+ /*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) parameter(0)
1966+ %get-tuple-element.189 = f32[4,4096,4096]{2,1,0}
1967+ get-tuple-element(%arg_tuple.187), index=1
1968+ %get-tuple-element.190 = f32[4,5,4096,8192]{3,2,1,0}
1969+ get-tuple-element(%arg_tuple.187), index=2
1970+ %get-tuple-element.191 = f32[4,5,4096,8192]{3,2,1,0}
1971+ get-tuple-element(%arg_tuple.187), index=3
1972+ %get-tuple-element.192 = f32[4,1,4096,8192]{3,2,1,0}
1973+ get-tuple-element(%arg_tuple.187), index=4
1974+ %get-tuple-element.193 = f32[4,1,4096,8192]{3,2,1,0}
1975+ get-tuple-element(%arg_tuple.187), index=5
1976+ %get-tuple-element.194 = s32[13]{0} get-tuple-element(%arg_tuple.187), index=6
1977+ %get-tuple-element.188 = s32[] get-tuple-element(%arg_tuple.187), index=0
1978+ %constant.195 = s32[] constant(13)
1979+ ROOT %compare.196 = pred[] compare(%get-tuple-element.188, %constant.195),
1980+ direction=LT
1981+ }
1982+
1983+ ENTRY %main.204 (Arg_0.1: f32[4,4096,4096], Arg_1.2: f32[4,5,4096,8192])
1984+ -> f32[4,5,4096,8192] {
1985+ %constant.3 = s32[] constant(0)
1986+ %Arg_0.1 = f32[4,4096,4096]{2,1,0} parameter(0),
1987+ sharding={devices=[4,1,1]<=[4]}
1988+ %Arg_1.2 = f32[4,5,4096,8192]{3,2,1,0} parameter(1),
1989+ sharding={devices=[4,1,1,1]<=[4]}
1990+ %constant.4 = f32[] constant(0)
1991+ %broadcast.5 = f32[4,5,4096,8192]{3,2,1,0} broadcast(%constant.4),
1992+ dimensions={}
1993+ %constant.6 = f32[] constant(0)
1994+ %broadcast.7 = f32[4,1,4096,8192]{3,2,1,0} broadcast(%constant.6),
1995+ dimensions={}
1996+ %iota.8 = s32[13]{0} iota(), iota_dimension=0
1997+ %tuple.9 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
1998+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
1999+ /*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) tuple(%constant.3,
2000+ %Arg_0.1, %Arg_1.2, %broadcast.5, %broadcast.7, /*index=5*/%broadcast.7,
2001+ %iota.8)
2002+ %while.197 = (s32[], f32[4,4096,4096]{2,1,0}, f32[4,5,4096,8192]{3,2,1,0},
2003+ f32[4,5,4096,8192]{3,2,1,0}, f32[4,1,4096,8192]{3,2,1,0},
2004+ /*index=5*/f32[4,1,4096,8192]{3,2,1,0}, s32[13]{0}) while(%tuple.9),
2005+ condition=%region_1.186, body=%region_0.166
2006+ %get-tuple-element.198 = s32[] get-tuple-element(%while.197), index=0
2007+ %get-tuple-element.199 = f32[4,4096,4096]{2,1,0}
2008+ get-tuple-element(%while.197), index=1
2009+ %get-tuple-element.200 = f32[4,5,4096,8192]{3,2,1,0}
2010+ get-tuple-element(%while.197), index=2
2011+ ROOT %get-tuple-element.201 = f32[4,5,4096,8192]{3,2,1,0}
2012+ get-tuple-element(%while.197), index=3
2013+ %get-tuple-element.202 = f32[4,1,4096,8192]{3,2,1,0}
2014+ get-tuple-element(%while.197), index=4
2015+ %get-tuple-element.203 = f32[4,1,4096,8192]{3,2,1,0}
2016+ get-tuple-element(%while.197), index=5
2017+ }
2018+ )" ;
2019+
2020+ const int64_t kNumReplicas = 1 ;
2021+ const int64_t kNumPartitions = 4 ;
2022+ if (test_runner ().device_count () < kNumReplicas * kNumPartitions ) {
2023+ GTEST_SKIP () << " Test requires at least " << kNumReplicas * kNumPartitions
2024+ << " devices (" << test_runner ().device_count ()
2025+ << " available)" ;
2026+ }
2027+
2028+ HloModuleConfig config = GetModuleConfigForTest (
2029+ /* replica_count=*/ kNumReplicas , /* num_partitions=*/ kNumPartitions );
2030+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
2031+ ParseAndReturnVerifiedModule (kModuleStr , config));
2032+
2033+ // Create device assignment running across partitions.
2034+ DeviceAssignment device_assignment (/* replica_count=*/ kNumReplicas ,
2035+ /* computation_count=*/ kNumPartitions );
2036+ for (int64_t i = 0 ; i < kNumPartitions ; ++i) {
2037+ device_assignment (0 , i) = i;
2038+ }
2039+
2040+ TF_ASSERT_OK_AND_ASSIGN (std::vector<Literal> fake_args,
2041+ MakeFakeArguments (module .get ()));
2042+ std::vector<Literal *> args;
2043+ for (auto &arg : fake_args) {
2044+ args.push_back (&arg);
2045+ }
2046+ TF_ASSERT_OK_AND_ASSIGN (
2047+ std::vector<Literal> results,
2048+ ExecuteReplicated (std::move (module ), args,
2049+ /* num_replicas=*/ kNumPartitions , &device_assignment,
2050+ /* run_hlo_passes=*/ true , /* use_threads=*/ true ));
2051+ ASSERT_EQ (results.size (), kNumPartitions );
2052+ }
2053+
17202054INSTANTIATE_TEST_SUITE_P (
17212055 CollectivePipelineParallelismTestWithAndWithoutOpts,
17222056 CollectivePipelineParallelismTest,
0 commit comments