diff --git a/desktop/libs/notebook/src/notebook/connectors/flink_sql.py b/desktop/libs/notebook/src/notebook/connectors/flink_sql.py index 38f2e6c2217..d4f4ec027dc 100644 --- a/desktop/libs/notebook/src/notebook/connectors/flink_sql.py +++ b/desktop/libs/notebook/src/notebook/connectors/flink_sql.py @@ -35,6 +35,7 @@ OPERATION_TOKEN = '%(username)s-%(connector_name)s' + '-operation-token' DEFAULT_CATALOG_PARAM = "default_catalog" DEFAULT_DATABASE_PARAM = "default_database" +LIST_ALL_FUNCTIONS_PARAM = "list_all_functions" def query_error_handler(func): @@ -74,6 +75,7 @@ def __init__(self, user, interpreter=None): api_url = self.options['url'] self.default_catalog = self.options.get(DEFAULT_CATALOG_PARAM) self.default_database = self.options.get(DEFAULT_DATABASE_PARAM) + self.list_all_functions = self.options.get(LIST_ALL_FUNCTIONS_PARAM) self.db = FlinkSqlClient(user=user, api_url=api_url) @@ -418,12 +420,18 @@ def _check_status_and_fetch_result(self, session_handle, operation_handle): data = [i['fields'] for i in resp['results']['data'] if resp and resp['results'] and resp['results']['data']] return data - def _show_databases(self): + def _show_catalogs(self): session = self._get_session() - session_handle = session['id'] + operation_handle = self.db.execute_statement(session_handle=session['id'], statement='SHOW CATALOGS') + catalog_list = self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle']) + + return [catalog[0] for catalog in catalog_list] - operation_handle = self.db.execute_statement(session_handle=session_handle, statement='SHOW DATABASES') - db_list = self._check_status_and_fetch_result(session_handle, operation_handle['operationHandle']) + def _show_databases(self, catalog=None): + session = self._get_session() + statement = 'SHOW DATABASES IN `%(catalog)s`' % {'catalog': catalog} if catalog else 'SHOW DATABASES' + operation_handle = self.db.execute_statement(session_handle=session['id'], statement=statement) + db_list = self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle']) return [db[0] for db in db_list] @@ -461,12 +469,51 @@ def _get_columns(self, database, table): ] def _show_functions(self, database): + if self.list_all_functions: + return self._show_all_functions() + else: + return self._show_functions_in_current_db(database) + + def _show_functions_in_current_db(self, database): session = self._get_session() statement = 'SHOW FUNCTIONS IN `%(database)s`' % {'database': database} if database else 'SHOW FUNCTIONS' operation_handle = self.db.execute_statement(session['id'], statement) function_list = self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle']) return [{'name': function[0]} for function in function_list] + def _show_all_functions(self): + # Flink UDFs can be registered in any catalog in any database. This function iterates through all catalogs + # and databases to lists all defined functions. The results are then extended with Flink system functions. + # Returned user functions names are fully qualified (..). + result = [] + + for catalog in self._show_catalogs(): + for database in self._show_databases(catalog): + user_functions = self._list_user_functions(catalog, database) + result.extend([{'name': f'{catalog}.{database}.{function["name"]}'} for function in user_functions]) + + result.extend(self._list_system_functions()) + return result + + def _list_system_functions(self): + # Flink allows to either list all functions or only user functions. + all_functions = self._list_functions(catalog=self.default_catalog, database=self.default_database, user_scope=False) + user_functions = self._list_user_functions(catalog=self.default_catalog, database=self.default_database) + system_functions = [f for f in all_functions if f not in user_functions] + return system_functions + + def _list_user_functions(self, catalog, database): + return self._list_functions(catalog, database, user_scope=True) + + def _list_functions(self, catalog, database, user_scope): + session = self._get_session() + statement = 'SHOW USER FUNCTIONS' if user_scope else 'SHOW FUNCTIONS' + if database: + statement = statement + ' IN `%(catalog)s`.`%(database)s`' % {'catalog': catalog, 'database': database} + operation_handle = self.db.execute_statement(session['id'], statement) + function_list = self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle']) + return [{'name': function[0]} for function in function_list] + def _show_function(self, function_name): session = self._get_session() if session.get('flink_version') and session['flink_version'].startswith('2.'): @@ -476,9 +523,11 @@ def _show_function(self, function_name): statement='DESCRIBE FUNCTION EXTENDED %(function_name)s' % {'function_name': function_name}) properties = dict(self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle'])) + # Function can be overloaded (multiple signatures). But only the first signature will be returned. + signatures = properties.get('signature').split('\n') return { 'name': function_name, - 'signature': properties.get('signature'), + 'signature': signatures[0], } else: return {'name': function_name} diff --git a/desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py b/desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py index 07cfde58190..f6ca3e8f6ac 100644 --- a/desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py +++ b/desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py @@ -86,18 +86,7 @@ def test_autocomplete_operation_functions(self, client_mock): mock_client_instance.create_session.return_value = {'sessionHandle': self.TEST_SESSION_HANDLE} mock_client_instance.info.return_value = {'version': '2.0.0'} mock_client_instance.execute_statement.return_value = {'operationHandle': self.TEST_OPERATION_HANDLE} - mock_client_instance.fetch_results.return_value = { - 'resultType': 'PAYLOAD', - 'resultKind': 'SUCCESS_WITH_CONTENT', - 'results': { - 'columns': [{'name': 'function name', 'logicalType': {'type': 'VARCHAR', 'nullable': True, 'length': 1000}}], - 'rowFormat': 'JSON', - 'data': [ - {'kind': 'INSERT', 'fields': ['lower']}, - {'kind': 'INSERT', 'fields': ['upper']} - ]}, - 'nextResultUri': f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON' - } + mock_client_instance.fetch_results.return_value = self._list_function_payload(['lower', 'upper']) # and: FlinkSqlApi instance with configuration flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter) @@ -111,7 +100,54 @@ def test_autocomplete_operation_functions(self, client_mock): # then mock_client_instance.execute_statement.assert_called_once_with(self.TEST_SESSION_HANDLE, 'SHOW FUNCTIONS') - assert autocomplete_result == {'functions': [{'name': 'lower'}, {'name': 'upper'}]} + self._assert_autocomplete_functions(autocomplete_result, ['lower', 'upper']) + + @patch('notebook.connectors.flink_sql.FlinkSqlClient') + def test_autocomplete_operation_functions_list_all(self, client_mock): + # given: mock interactions + def mock_execute_statement(session_handle, statement): + responses = { + 'SHOW CATALOGS': {'operationHandle': 'show-catalogs'}, + 'SHOW DATABASES IN `test_catalog`': {'operationHandle': 'show-databases'}, + 'SHOW FUNCTIONS IN `test_catalog`.`db_a`': {'operationHandle': 'show-fns-dba'}, + 'SHOW USER FUNCTIONS IN `test_catalog`.`db_a`': {'operationHandle': 'show-user-fns-dba'}, + 'SHOW USER FUNCTIONS IN `test_catalog`.`db_b`': {'operationHandle': 'show-user-fns-dbb'}, + } + return responses.get(statement) + + def mock_fetch_results(session_handle, operation_handle, token): + responses = { + 'show-catalogs': self._list_function_payload(['test_catalog']), + 'show-databases': self._list_function_payload(['db_a', 'db_b']), + 'show-fns-dba': self._list_function_payload(['test_fun_a', 'lower', 'upper']), + 'show-user-fns-dba': self._list_function_payload(['test_fun_a']), + 'show-user-fns-dbb': self._list_function_payload(['test_fun_b']), + } + return responses.get(operation_handle) + + mock_client_instance = MagicMock() + client_mock.return_value = mock_client_instance + mock_client_instance.create_session.return_value = {'sessionHandle': self.TEST_SESSION_HANDLE} + mock_client_instance.info.return_value = {'version': '2.0.0'} + mock_client_instance.execute_statement.side_effect = mock_execute_statement + mock_client_instance.fetch_results.side_effect = mock_fetch_results + + # and: FlinkSqlApi instance with configuration + self.interpreter['options']['list_all_functions'] = True + self.interpreter['options']['default_catalog'] = 'test_catalog' + self.interpreter['options']['default_database'] = 'db_a' + flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter) + + # and: session is created + flink_api.create_session(lang='flink', properties=None) + + # when + autocomplete_result = flink_api.autocomplete(snippet='dummy', database=None, table=None, column=None, + nested=None, operation='functions') + + # then + self._assert_autocomplete_functions(autocomplete_result, ['lower', 'upper', 'test_catalog.db_a.test_fun_a', + 'test_catalog.db_b.test_fun_b']) @patch('notebook.connectors.flink_sql.FlinkSqlClient') def test_autocomplete_operation_function_flink_1_x(self, client_mock): @@ -164,7 +200,8 @@ def test_autocomplete_operation_function_flink_2_x(self, client_mock): {'kind': 'INSERT', 'fields': ['signature', 'default_catalog.default_db.test_function(values ...)']}, ] }, - 'nextResultUri': f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON' + 'nextResultUri': + f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON' } # and: FlinkSqlApi instance with configuration @@ -185,3 +222,24 @@ def test_autocomplete_operation_function_flink_2_x(self, client_mock): assert autocomplete_result == { 'function': {'name': 'test_function', 'signature': 'default_catalog.default_db.test_function(values ...)'} } + + def _list_function_payload(self, expected_functions, session_handle=None, operation_handle=None, token=0): + session_handle = session_handle if session_handle else self.TEST_SESSION_HANDLE + operation_handle = operation_handle if operation_handle else self.TEST_OPERATION_HANDLE + + return { + 'resultType': 'PAYLOAD', + 'resultKind': 'SUCCESS_WITH_CONTENT', + 'results': { + 'columns': [ + {'name': 'function name', 'logicalType': {'type': 'VARCHAR', 'nullable': True, 'length': 1000}}], + 'rowFormat': 'JSON', + 'data': [ + {'kind': 'INSERT', 'fields': [fname]} for fname in expected_functions + ]}, + 'nextResultUri': f'/v3/sessions/{session_handle}/operations/{operation_handle}/result/1?rowFormat=JSON' + } + + def _assert_autocomplete_functions(self, autocomplete_result, expected_fun_names): + actual_fun_names = set([f['name'] for f in autocomplete_result['functions']]) + assert set(expected_fun_names) == actual_fun_names