@@ -81,61 +81,103 @@ private static void prepareDataForTableModel() throws SQLException {
8181
8282 @ Test
8383 public void concurrentCPUCallInferenceTest () throws SQLException , InterruptedException {
84+ concurrentCPUCallInferenceTest ("timer_xl" );
85+ concurrentCPUCallInferenceTest ("sundial" );
86+ }
87+
88+ private void concurrentCPUCallInferenceTest (String modelId )
89+ throws SQLException , InterruptedException {
8490 try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
8591 Statement statement = connection .createStatement ()) {
86- statement .execute ("LOAD MODEL sundial TO DEVICES \" cpu\" " );
87- concurrentInference (statement , "CALL INFERENCE(sundial, \" SELECT s FROM root.AI\" )" , 4 , 10 );
92+ final int threadCnt = 4 ;
93+ final int loop = 10 ;
94+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ));
95+ concurrentInference (
96+ statement ,
97+ String .format ("CALL INFERENCE(%s, \" SELECT s FROM root.AI\" )" , modelId ),
98+ threadCnt ,
99+ loop );
100+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" cpu\" " , modelId ));
88101 }
89102 }
90103
91104 @ Test
92105 public void concurrentGPUCallInferenceTest () throws SQLException , InterruptedException {
106+ concurrentGPUCallInferenceTest ("timer_xl" );
107+ concurrentGPUCallInferenceTest ("sundial" );
108+ }
109+
110+ private void concurrentGPUCallInferenceTest (String modelId )
111+ throws SQLException , InterruptedException {
93112 try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
94113 Statement statement = connection .createStatement ()) {
95- statement .execute ("LOAD MODEL sundial TO DEVICES \" 0,1\" " );
96- concurrentInference (statement , "CALL INFERENCE(sundial, \" SELECT s FROM root.AI\" )" , 10 , 100 );
114+ final int threadCnt = 4 ;
115+ final int loop = 10 ;
116+ statement .execute (String .format ("LOAD MODEL %s TO DEVICES \" 0,1\" " , modelId ));
117+ concurrentInference (
118+ statement ,
119+ String .format ("CALL INFERENCE(%s, \" SELECT s FROM root.AI\" )" , modelId ),
120+ threadCnt ,
121+ loop );
122+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
97123 }
98124 }
99125
100126 @ Test
101127 public void concurrentCPUForecastTest () throws SQLException , InterruptedException {
102- try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
128+ concurrentCPUForecastTest ("timer_xl" );
129+ concurrentCPUForecastTest ("sundial" );
130+ }
131+
132+ private void concurrentCPUForecastTest (String modelId ) throws SQLException , InterruptedException {
133+ try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TABLE_SQL_DIALECT );
103134 Statement statement = connection .createStatement ()) {
104135 final int threadCnt = 4 ;
105136 final int loop = 10 ;
106- statement .execute ("LOAD MODEL sundial TO DEVICES \" cpu\" " );
137+ statement .execute (String . format ( "LOAD MODEL %s TO DEVICES \" cpu\" " , modelId ) );
107138 long startTime = System .currentTimeMillis ();
108139 concurrentInference (
109140 statement ,
110- "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
141+ String .format (
142+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
143+ modelId ),
111144 threadCnt ,
112145 loop );
113146 long endTime = System .currentTimeMillis ();
114147 LOGGER .info (
115148 String .format (
116- "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms" ,
117- threadCnt * loop , threadCnt , loop , endTime - startTime ));
149+ "Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms" ,
150+ modelId , threadCnt * loop , threadCnt , loop , endTime - startTime ));
151+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" cpu\" " , modelId ));
118152 }
119153 }
120154
121155 @ Test
122156 public void concurrentGPUForecastTest () throws SQLException , InterruptedException {
123- try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
157+ concurrentGPUForecastTest ("timer_xl" );
158+ concurrentGPUForecastTest ("sundial" );
159+ }
160+
161+ public void concurrentGPUForecastTest (String modelId ) throws SQLException , InterruptedException {
162+ try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TABLE_SQL_DIALECT );
124163 Statement statement = connection .createStatement ()) {
125164 final int threadCnt = 10 ;
126165 final int loop = 100 ;
127- statement .execute ("LOAD MODEL sundial TO DEVICES \" 0,1\" " );
166+ statement .execute (String . format ( "LOAD MODEL %s TO DEVICES \" 0,1\" " , modelId ) );
128167 long startTime = System .currentTimeMillis ();
129168 concurrentInference (
130169 statement ,
131- "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
170+ String .format (
171+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)" ,
172+ modelId ),
132173 threadCnt ,
133174 loop );
134175 long endTime = System .currentTimeMillis ();
135176 LOGGER .info (
136177 String .format (
137- "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms" ,
138- threadCnt * loop , threadCnt , loop , endTime - startTime ));
178+ "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms" ,
179+ modelId , threadCnt * loop , threadCnt , loop , endTime - startTime ));
180+ statement .execute (String .format ("UNLOAD MODEL %s FROM DEVICES \" 0,1\" " , modelId ));
139181 }
140182 }
141183}
0 commit comments