1111#include " tuner/nccl_ofi_tuner_region.h"
1212#include " nccl_ofi_param.h"
1313
14- using std::abs;
15- using std::log2;
16- using std::pow;
17- const double eps = 1e-4 ;
18-
1914static int test_extend_region (void )
2015{
2116 nccl_ofi_tuner_point_t extended_point;
@@ -26,9 +21,8 @@ static int test_extend_region(void)
2621 extended_point = extend_region ((nccl_ofi_tuner_point_t ){2 , 8 },
2722 (nccl_ofi_tuner_point_t ){4 , 8 },
2823 (nccl_ofi_tuner_point_t ){TUNER_MAX_SIZE, TUNER_MAX_RANKS});
29- if (abs (extended_point.x - TUNER_MAX_SIZE) > eps || extended_point.y != 8 ) {
30- printf (" X-Axis Extend Test Failed : Extended Points : x = %f (diff = %f) y = %f\n " , extended_point.x ,
31- extended_point.x - TUNER_MAX_SIZE, extended_point.y );
24+ if (extended_point.x != TUNER_MAX_SIZE || extended_point.y != 8 ) {
25+ printf (" X-Axis Extend Test Failed : Extended Points : x = %f y = %f\n " , extended_point.x , extended_point.y );
3226 return -1 ;
3327 }
3428
@@ -45,26 +39,22 @@ static int test_extend_region(void)
4539 extended_point = extend_region ((nccl_ofi_tuner_point_t ){8 , 64 },
4640 (nccl_ofi_tuner_point_t ){8290304 , 72 },
4741 (nccl_ofi_tuner_point_t ){TUNER_MAX_SIZE, TUNER_MAX_RANKS});
48- slope = (log2 ( 72.0 ) - log2 ( 64.0 )) / (log2 ( 8290304.0 ) - log2 ( 8.0 ) ); // slope = (y2 - y1)/(x2 - x1)
42+ slope = (72.0 - 64.0 ) / (8290304.0 - 8.0 ); // slope = (y2 - y1)/(x2 - x1)
4943 // y3 = mx3 + c and substitute for m=(y2-y1)/(x2-x1) and c = y2 - mx2
50- projected_y = pow (2.0 , log2 (72.0 ) + slope * (log2 (TUNER_MAX_SIZE) - log2 (8290304.0 ))); // y3 = y2 + mx3 - mx2
51- if (abs (extended_point.x - TUNER_MAX_SIZE) > eps || extended_point.y != projected_y) {
52- printf (" X-Axis Upper Bound Test Failed : Extended Points : x = %f (diff = %f) y = %f (diff = %f) \n " ,
53- extended_point.x , extended_point.x - TUNER_MAX_SIZE,
54- extended_point.y , extended_point.y - projected_y);
44+ projected_y = 72.0 + slope * (TUNER_MAX_SIZE - 8290304.0 ); // y3 = y2 + mx3 - mx2
45+ if (extended_point.x != TUNER_MAX_SIZE || extended_point.y != projected_y) {
46+ printf (" X-Axis Upper Bound Test Failed : Extended Points : x = %f y = %f\n " , extended_point.x , extended_point.y );
5547 return -1 ;
5648 }
5749
5850 /* Extend the line to TUNER_MAX_RANKS (y-axis) */
5951 extended_point = extend_region ((nccl_ofi_tuner_point_t ){8 , 64 },
6052 (nccl_ofi_tuner_point_t ){16 , 1024 },
6153 (nccl_ofi_tuner_point_t ){TUNER_MAX_SIZE, TUNER_MAX_RANKS});
62- slope = (log2 (1024.0 ) - log2 (64.0 )) / (log2 (16 ) - log2 (8.0 ));
63- projected_x = pow (2 , ((log2 (TUNER_MAX_RANKS) - log2 (1024.0 )) / slope) + log2 (16 ));
64- if (abs (extended_point.x - projected_x) > eps || extended_point.y != TUNER_MAX_RANKS) {
65- printf (" X-Axis Upper Bound Test 2 Failed : Extended Points : x = %f (diff = %f) y = %f (diff = %f) \n " ,
66- extended_point.x , extended_point.x - projected_x,
67- extended_point.y , extended_point.y - TUNER_MAX_RANKS);
54+ slope = (1024.0 - 64.0 ) / (16.0 - 8.0 );
55+ projected_x = ((TUNER_MAX_RANKS - 1024.0 ) / slope) + 16 ;
56+ if (extended_point.x != projected_x || extended_point.y != TUNER_MAX_RANKS) {
57+ printf (" X-Axis Upper Bound Test Failed : Extended Points : x = %f y = %f\n " , extended_point.x , extended_point.y );
6858 return -1 ;
6959 }
7060
@@ -85,7 +75,7 @@ static int test_extend_region(void)
8575| . |
8676| . |
8777|--------*------*----------*----------*---------*--------*---
88- | p3(4M, 2) |p5 (TUNER_MAX_SIZE, 2))
78+ | p3(4M, 2) |p4 (TUNER_MAX_SIZE, 2))
8979| |
9080*/
9181static int test_is_inside_region (void ) {
@@ -97,14 +87,6 @@ static int test_is_inside_region(void) {
9787 (nccl_ofi_tuner_point_t ){(double )48.0 * 1024 * 1024 , 16 },
9888 (nccl_ofi_tuner_point_t ){(double )288.0 * 1024 * 1024 , 128 },
9989 (nccl_ofi_tuner_point_t ){TUNER_MAX_SIZE, TUNER_MAX_RANKS});
100- printf (" INFO extended point: %f %f \n " , e_48M_16_288M_128.x , e_48M_16_288M_128.y );
101-
102- p1_288M_128.transform_log2 ();
103- p2_38M_16.transform_log2 ();
104- p3_4M_2.transform_log2 ();
105- p5_maxM_2.transform_log2 ();
106- e_48M_16_288M_128.transform_log2 ();
107- printf (" INFO extended point after transform_log2: %f %f \n " , e_48M_16_288M_128.x , e_48M_16_288M_128.y );
10890
10991 nccl_ofi_tuner_region_t region = {
11092 .algorithm = NCCL_ALGO_RING,
@@ -134,17 +116,15 @@ static int test_is_inside_region(void) {
134116 To find the points on the edge of the polygons:
135117 1. Consider two vertices of the polygon
136118 2. Calculate the slope and y-intercept of the line.
137- 3. Using the equation y = m * x + c, get multiple points on the line in powers of 2.
119+ 3. Using the equation y = mx + c, get multiple points on the line in powers of 2.
138120 */
139121 for (size_t i = 0 ; i < region.num_vertices ; i++) {
140122 size_t k = (i + 1 ) % region.num_vertices ;
141123 double slope = (region.vertices [k].y - region.vertices [i].y ) / (region.vertices [k].x - region.vertices [i].x );
142124 double c = region.vertices [k].y - (slope * (region.vertices [i].x ));
143125 for (double x = region.vertices [i].x ; x < region.vertices [k].x ; x = x * 2 ) {
144126 double y = (slope * x) + c;
145- nccl_ofi_tuner_point_t test_point {x, y, nccl_ofi_tuner_point_t ::LOG2};
146-
147- if (is_inside_region (test_point, ®ion) != 0 )
127+ if (is_inside_region ((nccl_ofi_tuner_point_t ){x, y}, ®ion) != 0 )
148128 return -1 ;
149129 // printf(" Is (%.10f, %.10f) inside the region : %d\n", x, y, is_inside_region(
150130 // (nccl_ofi_tuner_point_t){x, y}, ®ion));
@@ -153,8 +133,8 @@ static int test_is_inside_region(void) {
153133
154134 printf (" All points on the edges of the polygon are detected correcltly\n " );
155135
156- const size_t num_points = 20 ;
157- nccl_ofi_tuner_point_t inside_vertices[] = {{16.0 * 1024 * 1024 , 4 },
136+ size_t num_points = 20 ;
137+ const nccl_ofi_tuner_point_t inside_vertices[] = {{16.0 * 1024 * 1024 , 4 },
158138 {128.0 * 1024 * 1024 , 4 },
159139 {1.0 * 1024 * 1024 * 1024 , 4 },
160140 {4.0 * 1024 * 1024 * 1024 , 4 },
@@ -172,24 +152,19 @@ static int test_is_inside_region(void) {
172152 {32.0 * 1024 * 1024 * 1024 , 128 },
173153 {64.0 * 1024 * 1024 * 1024 , 128 },
174154 {64.0 * 1024 * 1024 * 1024 , 256 },
175- // Note, set a big enough diff (10.0) below, otherwise
176- // the delta after log2 is within floating error (eps).
177- {TUNER_MAX_SIZE - 10.0 , 128 },
178- {e_48M_16_288M_128.x - 1.0 , e_48M_16_288M_128.y - 10.0 , nccl_ofi_tuner_point_t ::LOG2}};
155+ {TUNER_MAX_SIZE - 1.0 , 128 },
156+ {e_48M_16_288M_128.x - 1.0 , e_48M_16_288M_128.y - 1.0 }};
179157
180158 /* These points should be inside the polygon */
181159 for (size_t i = 0 ; i < num_points; i++) {
182- inside_vertices[i].transform_log2 ();
183- int d = is_inside_region (inside_vertices[i], ®ion);
184- if (d != 1 ) {
185- printf (" %ld: %.10f, %.10f is_inside_region: %d\n " , i, inside_vertices[i].x , inside_vertices[i].y , d);
160+ if (is_inside_region (inside_vertices[i], ®ion) != 1 ) {
161+ printf (" %.10f, %.10f\n " , inside_vertices[i].x , inside_vertices[i].y );
186162 return -1 ;
187163 };
188164 }
189165
190166 printf (" All points inside the polygon are detected correcltly\n " );
191167
192- const size_t outside_num_points = 24 ;
193168 const nccl_ofi_tuner_point_t outside_vertices[] = {{8.0 * 1024 * 1024 , 4 },
194169 {8.0 * 1024 * 1024 , 32 },
195170 {8.0 * 1024 * 1024 , 128 },
@@ -216,10 +191,9 @@ static int test_is_inside_region(void) {
216191 {e_48M_16_288M_128.x + 1.0 , e_48M_16_288M_128.y + 1.0 }};
217192
218193 /* These points should be outside the polygons */
219- for (size_t i = 0 ; i < outside_num_points; i++) {
220- int d = is_inside_region (outside_vertices[i], ®ion);
221- if ( d != -1 ) {
222- printf (" %ld: %.10f, %.10f is_inside_region: %d\n " , i, outside_vertices[i].x , outside_vertices[i].y , d);
194+ for (size_t i = 0 ; i < num_points; i++) {
195+ if (is_inside_region (outside_vertices[i], ®ion) != -1 ) {
196+ printf (" %.10f, %.10f\n " , outside_vertices[i].x , outside_vertices[i].y );
223197 return -1 ;
224198 };
225199 }
@@ -240,17 +214,13 @@ static int test_is_inside_region(void) {
240214}
241215
242216int main (int argc, const char **argv) {
243- int ret = 0 ;
244- if ((ret |= test_extend_region ()) < 0 ) {
217+ if (test_extend_region () < 0 ) {
245218 printf (" Extend Region test failed\n " );
246219 };
247220
248- if ((ret |= test_is_inside_region () ) < 0 ) {
221+ if (test_is_inside_region () < 0 ) {
249222 printf (" Is inside region function failed\n " );
250223 }
251224
252- if (ret == 0 ) {
253- printf (" All tests passed.\n " );
254- }
255- return ret;
225+ return 0 ;
256226}
0 commit comments