2222import org .apache .iotdb .it .env .EnvFactory ;
2323import org .apache .iotdb .itbase .env .BaseEnv ;
2424
25+ import com .google .common .collect .ImmutableMap ;
26+ import com .google .common .collect .ImmutableSet ;
2527import org .junit .AfterClass ;
28+ import org .junit .Assert ;
2629import org .junit .BeforeClass ;
2730import org .junit .Test ;
2831import org .slf4j .Logger ;
2932import org .slf4j .LoggerFactory ;
3033
3134import java .sql .Connection ;
35+ import java .sql .ResultSet ;
3236import java .sql .SQLException ;
3337import java .sql .Statement ;
38+ import java .util .HashSet ;
39+ import java .util .Map ;
40+ import java .util .Set ;
41+ import java .util .concurrent .TimeUnit ;
3442
3543import static org .apache .iotdb .ainode .utils .AINodeTestUtils .concurrentInference ;
3644
3745public class AINodeConcurrentInferenceIT {
3846
3947 private static final Logger LOGGER = LoggerFactory .getLogger (AINodeConcurrentInferenceIT .class );
4048
49+ private static final Map <String , String > MODEL_ID_TO_TYPE_MAP =
50+ ImmutableMap .of (
51+ "timer_xl" , "Timer-XL" ,
52+ "sundial" , "Timer-Sundial" );
53+
4154 @ BeforeClass
4255 public static void setUp () throws Exception {
4356 // Init 1C1D1A cluster environment
@@ -91,12 +104,17 @@ private void concurrentCPUCallInferenceTest(String modelId)
91104 Statement statement = connection .createStatement ()) {
92105 final int threadCnt = 4 ;
93106 final int loop = 10 ;
107+ final int predictLength = 96 ;
94108 statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ));
109+ checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP .get (modelId ), "cpu" );
95110 concurrentInference (
96111 statement ,
97- String .format ("CALL INFERENCE(%s, \" SELECT s FROM root.AI\" )" , modelId ),
112+ String .format (
113+ "CALL INFERENCE(%s, \" SELECT s FROM root.AI\" , predict_length=%d)" ,
114+ modelId , predictLength ),
98115 threadCnt ,
99- loop );
116+ loop ,
117+ predictLength );
100118 statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" cpu\" " , modelId ));
101119 }
102120 }
@@ -111,14 +129,20 @@ private void concurrentGPUCallInferenceTest(String modelId)
111129 throws SQLException , InterruptedException {
112130 try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
113131 Statement statement = connection .createStatement ()) {
114- final int threadCnt = 4 ;
115- final int loop = 10 ;
116- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" 0,1\" " , modelId ));
132+ final int threadCnt = 10 ;
133+ final int loop = 100 ;
134+ final int predictLength = 512 ;
135+ final String devices = "0,1" ;
136+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" %s\" " , modelId , devices ));
137+ checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP .get (modelId ), devices );
117138 concurrentInference (
118139 statement ,
119- String .format ("CALL INFERENCE(%s, \" SELECT s FROM root.AI\" )" , modelId ),
140+ String .format (
141+ "CALL INFERENCE(%s, \" SELECT s FROM root.AI\" , predict_length=%d)" ,
142+ modelId , predictLength ),
120143 threadCnt ,
121- loop );
144+ loop ,
145+ predictLength );
122146 statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
123147 }
124148 }
@@ -134,15 +158,18 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
134158 Statement statement = connection .createStatement ()) {
135159 final int threadCnt = 4 ;
136160 final int loop = 10 ;
161+ final int predictLength = 96 ;
137162 statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ));
163+ checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP .get (modelId ), "cpu" );
138164 long startTime = System .currentTimeMillis ();
139165 concurrentInference (
140166 statement ,
141167 String .format (
142- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
143- modelId ),
168+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d " ,
169+ modelId , predictLength ),
144170 threadCnt ,
145- loop );
171+ loop ,
172+ predictLength );
146173 long endTime = System .currentTimeMillis ();
147174 LOGGER .info (
148175 String .format (
@@ -163,15 +190,19 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
163190 Statement statement = connection .createStatement ()) {
164191 final int threadCnt = 10 ;
165192 final int loop = 100 ;
166- statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" 0,1\" " , modelId ));
193+ final int predictLength = 512 ;
194+ final String devices = "0,1" ;
195+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" %s\" " , modelId , devices ));
196+ checkModelOnSpecifiedDevice (statement , MODEL_ID_TO_TYPE_MAP .get (modelId ), devices );
167197 long startTime = System .currentTimeMillis ();
168198 concurrentInference (
169199 statement ,
170200 String .format (
171- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
172- modelId ),
201+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d " ,
202+ modelId , predictLength ),
173203 threadCnt ,
174- loop );
204+ loop ,
205+ predictLength );
175206 long endTime = System .currentTimeMillis ();
176207 LOGGER .info (
177208 String .format (
@@ -180,4 +211,29 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
180211 statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
181212 }
182213 }
214+
215+ private void checkModelOnSpecifiedDevice (Statement statement , String modelType , String device )
216+ throws SQLException , InterruptedException {
217+ for (int retry = 0 ; retry < 10 ; retry ++) {
218+ Set <String > targetDevices = ImmutableSet .copyOf (device .split ("," ));
219+ Set <String > foundDevices = new HashSet <>();
220+ try (final ResultSet resultSet =
221+ statement .executeQuery (String .format ("SHOW LOADED MODELS %s" , device ))) {
222+ while (resultSet .next ()) {
223+ String deviceId = resultSet .getString (1 );
224+ String loadedModelType = resultSet .getString (2 );
225+ int count = resultSet .getInt (3 );
226+ if (loadedModelType .equals (modelType ) && targetDevices .contains (deviceId )) {
227+ Assert .assertTrue (count > 1 );
228+ foundDevices .add (deviceId );
229+ }
230+ }
231+ if (foundDevices .containsAll (targetDevices )) {
232+ return ;
233+ }
234+ }
235+ TimeUnit .SECONDS .sleep (3 );
236+ }
237+ Assert .fail ("Model " + modelType + " is not loaded on device " + device );
238+ }
183239}
0 commit comments