2020package org .apache .iotdb .ainode .it ;
2121
2222import org .apache .iotdb .it .env .EnvFactory ;
23+ import org .apache .iotdb .it .framework .IoTDBTestRunner ;
24+ import org .apache .iotdb .itbase .category .AIClusterIT ;
2325import org .apache .iotdb .itbase .env .BaseEnv ;
2426
25- import com .google .common .collect .ImmutableMap ;
2627import com .google .common .collect .ImmutableSet ;
2728import org .junit .AfterClass ;
2829import org .junit .Assert ;
2930import org .junit .BeforeClass ;
3031import org .junit .Test ;
32+ import org .junit .experimental .categories .Category ;
33+ import org .junit .runner .RunWith ;
3134import org .slf4j .Logger ;
3235import org .slf4j .LoggerFactory ;
3336
3639import java .sql .SQLException ;
3740import java .sql .Statement ;
3841import java .util .HashSet ;
39- import java .util .Map ;
4042import java .util .Set ;
4143import java .util .concurrent .TimeUnit ;
4244
4345import static org .apache .iotdb .ainode .utils .AINodeTestUtils .concurrentInference ;
4446
47+ @ RunWith (IoTDBTestRunner .class )
48+ @ Category ({AIClusterIT .class })
4549public class AINodeConcurrentInferenceIT {
4650
4751 private static final Logger LOGGER = LoggerFactory .getLogger (AINodeConcurrentInferenceIT .class );
4852
49- private static final Map <String , String > MODEL_ID_TO_TYPE_MAP =
50- ImmutableMap .of (
51- "timer_xl" , "Timer-XL" ,
52- "sundial" , "Timer-Sundial" );
53-
5453 @ BeforeClass
5554 public static void setUp () throws Exception {
5655 // Init 1C1D1A cluster environment
@@ -86,13 +85,12 @@ private static void prepareDataForTableModel() throws SQLException {
8685 for (int i = 0 ; i < 2880 ; i ++) {
8786 statement .execute (
8887 String .format (
89- "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)" ,
90- i , Math .sin (i * Math .PI / 1440 )));
88+ "INSERT INTO root.AI(time, s) VALUES(%d, %f)" , i , Math .sin (i * Math .PI / 1440 )));
9189 }
9290 }
9391 }
9492
95- @ Test
93+ // @Test
9694 public void concurrentCPUCallInferenceTest () throws SQLException , InterruptedException {
9795 concurrentCPUCallInferenceTest ("timer_xl" );
9896 concurrentCPUCallInferenceTest ("sundial" );
@@ -105,21 +103,21 @@ private void concurrentCPUCallInferenceTest(String modelId)
105103 final int threadCnt = 4 ;
106104 final int loop = 10 ;
107105 final int predictLength = 96 ;
108- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ));
109- checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP . get ( modelId ) , "cpu" );
106+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES ' cpu' " , modelId ));
107+ checkModelOnSpecifiedDevice (statement , modelId , "cpu" );
110108 concurrentInference (
111109 statement ,
112110 String .format (
113- "CALL INFERENCE(%s, \" SELECT s FROM root.AI\" , predict_length=%d)" ,
111+ "CALL INFERENCE(%s, ' SELECT s FROM root.AI' , predict_length=%d)" ,
114112 modelId , predictLength ),
115113 threadCnt ,
116114 loop ,
117115 predictLength );
118- statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" cpu\" " , modelId ));
116+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES ' cpu' " , modelId ));
119117 }
120118 }
121119
122- @ Test
120+ // @Test
123121 public void concurrentGPUCallInferenceTest () throws SQLException , InterruptedException {
124122 concurrentGPUCallInferenceTest ("timer_xl" );
125123 concurrentGPUCallInferenceTest ("sundial" );
@@ -133,17 +131,17 @@ private void concurrentGPUCallInferenceTest(String modelId)
133131 final int loop = 100 ;
134132 final int predictLength = 512 ;
135133 final String devices = "0,1" ;
136- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" %s \" " , modelId , devices ));
137- checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP . get ( modelId ) , devices );
134+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES '%s' " , modelId , devices ));
135+ checkModelOnSpecifiedDevice (statement , modelId , devices );
138136 concurrentInference (
139137 statement ,
140138 String .format (
141- "CALL INFERENCE(%s, \" SELECT s FROM root.AI\" , predict_length=%d)" ,
139+ "CALL INFERENCE(%s, ' SELECT s FROM root.AI' , predict_length=%d)" ,
142140 modelId , predictLength ),
143141 threadCnt ,
144142 loop ,
145143 predictLength );
146- statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
144+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES ' 0,1' " , modelId ));
147145 }
148146 }
149147
@@ -159,8 +157,8 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
159157 final int threadCnt = 4 ;
160158 final int loop = 10 ;
161159 final int predictLength = 96 ;
162- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ));
163- checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP . get ( modelId ) , "cpu" );
160+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES ' cpu' " , modelId ));
161+ checkModelOnSpecifiedDevice (statement , modelId , "cpu" );
164162 long startTime = System .currentTimeMillis ();
165163 concurrentInference (
166164 statement ,
@@ -175,7 +173,7 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
175173 String .format (
176174 "Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms" ,
177175 modelId , threadCnt * loop , threadCnt , loop , endTime - startTime ));
178- statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" cpu\" " , modelId ));
176+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES ' cpu' " , modelId ));
179177 }
180178 }
181179
@@ -192,8 +190,8 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
192190 final int loop = 100 ;
193191 final int predictLength = 512 ;
194192 final String devices = "0,1" ;
195- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" %s \" " , modelId , devices ));
196- checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP . get ( modelId ) , devices );
193+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES '%s' " , modelId , devices ));
194+ checkModelOnSpecifiedDevice (statement , modelId , devices );
197195 long startTime = System .currentTimeMillis ();
198196 concurrentInference (
199197 statement ,
@@ -208,32 +206,35 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
208206 String .format (
209207 "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms" ,
210208 modelId , threadCnt * loop , threadCnt , loop , endTime - startTime ));
211- statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
209+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES ' 0,1' " , modelId ));
212210 }
213211 }
214212
215- private void checkModelOnSpecifiedDevice (Statement statement , String modelType , String device )
213+ private void checkModelOnSpecifiedDevice (Statement statement , String modelId , String device )
216214 throws SQLException , InterruptedException {
217- for (int retry = 0 ; retry < 10 ; retry ++) {
218- Set <String > targetDevices = ImmutableSet .copyOf (device .split ("," ));
215+ Set <String > targetDevices = ImmutableSet .copyOf (device .split ("," ));
216+ LOGGER .info ("Checking model: {} on target devices: {}" , modelId , targetDevices );
217+ for (int retry = 0 ; retry < 20 ; retry ++) {
219218 Set <String > foundDevices = new HashSet <>();
220219 try (final ResultSet resultSet =
221- statement .executeQuery (String .format ("SHOW LOADED MODELS %s " , device ))) {
220+ statement .executeQuery (String .format ("SHOW LOADED MODELS '%s' " , device ))) {
222221 while (resultSet .next ()) {
223- String deviceId = resultSet .getString (1 );
224- String loadedModelType = resultSet .getString (2 );
225- int count = resultSet .getInt (3 );
226- if ( loadedModelType . equals ( modelType ) && targetDevices . contains ( deviceId )) {
227- Assert . assertTrue ( count > 1 );
222+ String deviceId = resultSet .getString ("DeviceId" );
223+ String loadedModelId = resultSet .getString ("ModelId" );
224+ int count = resultSet .getInt ("Count(instances)" );
225+ LOGGER . info ( "Model {} found in device {}, count {}" , loadedModelId , deviceId , count );
226+ if ( loadedModelId . equals ( modelId ) && targetDevices . contains ( deviceId ) && count > 0 ) {
228227 foundDevices .add (deviceId );
228+ LOGGER .info ("Model {} is loaded to device {}" , modelId , device );
229229 }
230230 }
231231 if (foundDevices .containsAll (targetDevices )) {
232+ LOGGER .info ("Model {} is loaded to devices {}, start testing" , modelId , targetDevices );
232233 return ;
233234 }
234235 }
235236 TimeUnit .SECONDS .sleep (3 );
236237 }
237- Assert .fail ("Model " + modelType + " is not loaded on device " + device );
238+ Assert .fail ("Model " + modelId + " is not loaded on device " + device );
238239 }
239240}
0 commit comments