diff --git a/catalogs/hive-metastore-common/build.gradle.kts b/catalogs/hive-metastore-common/build.gradle.kts index 406b0fd31cf..de365651891 100644 --- a/catalogs/hive-metastore-common/build.gradle.kts +++ b/catalogs/hive-metastore-common/build.gradle.kts @@ -120,6 +120,7 @@ dependencies { exclude("org.slf4j") } testImplementation(libs.junit.jupiter.api) + testImplementation(libs.mockito.core) testImplementation(libs.woodstox.core) testImplementation(libs.testcontainers) testImplementation(project(":integration-test-common", "testArtifacts")) diff --git a/catalogs/hive-metastore-common/src/main/java/org/apache/gravitino/hive/HiveClientPool.java b/catalogs/hive-metastore-common/src/main/java/org/apache/gravitino/hive/HiveClientPool.java index 7fa0098a34d..e4d78e58c76 100644 --- a/catalogs/hive-metastore-common/src/main/java/org/apache/gravitino/hive/HiveClientPool.java +++ b/catalogs/hive-metastore-common/src/main/java/org/apache/gravitino/hive/HiveClientPool.java @@ -41,7 +41,7 @@ public class HiveClientPool extends ClientPoolImplSource: hive-metastore/src/test/java/org/apache/iceberg/hive/TestHiveClientPool.java + */ +public class TestHiveClientPool { + + private HiveClientPool clients; + + @BeforeEach + public void before() { + HiveClientPool clientPool = new HiveClientPool("hive", 2, new Properties()); + clients = Mockito.spy(clientPool); + } + + @AfterEach + public void after() { + clients.close(); + clients = null; + } + + @Test + public void testNewClientFailure() { + Mockito.doThrow(new RuntimeException("Connection exception")).when(clients).newClient(); + RuntimeException ex = assertThrows(RuntimeException.class, () -> clients.run(Object::toString)); + assertEquals("Connection exception", ex.getMessage()); + } + + @Test + public void testReconnect() { + HiveClient hiveClient = newClient(); + + String metaMessage = "Got exception: org.apache.thrift.transport.TTransportException"; + Mockito.doThrow(new GravitinoRuntimeException(metaMessage)) + .when(hiveClient) + .getAllDatabases(""); + + GravitinoRuntimeException ex = + assertThrows( + GravitinoRuntimeException.class, + () -> clients.run(client -> client.getAllDatabases(""))); + assertEquals("Got exception: org.apache.thrift.transport.TTransportException", ex.getMessage()); + // Verify that the method is never called. + Mockito.verify(clients, Mockito.never()).reconnect(hiveClient); + } + + @Test + public void testClose() throws Exception { + HiveClient hiveClient = newClient(); + + List databases = Lists.newArrayList("db1", "db2"); + Mockito.doReturn(databases).when(hiveClient).getAllDatabases(""); + assertEquals(clients.run(client -> client.getAllDatabases("")), databases); + + clients.close(); + assertTrue(clients.isClosed()); + Mockito.verify(hiveClient).close(); + } + + private HiveClient newClient() { + HiveClient hiveClient = Mockito.mock(HiveClient.class); + Mockito.doReturn(hiveClient).when(clients).newClient(); + return hiveClient; + } +}