Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions desktop/libs/notebook/src/notebook/connectors/flink_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# Format template expanded with `username` and `connector_name` keys.
# NOTE(review): presumably names the stored per-user operation token - confirm at the usage site.
OPERATION_TOKEN = '%(username)s-%(connector_name)s' + '-operation-token'
# Keys looked up in the connector options (read via `self.options.get(...)` in `__init__`).
DEFAULT_CATALOG_PARAM = "default_catalog"
DEFAULT_DATABASE_PARAM = "default_database"
# Option key stored as `self.list_all_functions`; when truthy, `_show_functions`
# delegates to `_show_all_functions` instead of listing a single database.
LIST_ALL_FUNCTIONS_PARAM = "list_all_functions"


def query_error_handler(func):
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(self, user, interpreter=None):
api_url = self.options['url']
self.default_catalog = self.options.get(DEFAULT_CATALOG_PARAM)
self.default_database = self.options.get(DEFAULT_DATABASE_PARAM)
self.list_all_functions = self.options.get(LIST_ALL_FUNCTIONS_PARAM)

self.db = FlinkSqlClient(user=user, api_url=api_url)

Expand Down Expand Up @@ -418,12 +420,18 @@ def _check_status_and_fetch_result(self, session_handle, operation_handle):
data = [i['fields'] for i in resp['results']['data'] if resp and resp['results'] and resp['results']['data']]
return data

def _show_catalogs(self):
    """Return the names of the catalogs reported by ``SHOW CATALOGS``.

    The statement is executed in the current session and the first field of
    every result row is returned.
    """
    session = self._get_session()
    session_handle = session['id']
    operation_handle = self.db.execute_statement(session_handle=session_handle, statement='SHOW CATALOGS')
    catalog_list = self._check_status_and_fetch_result(session_handle, operation_handle['operationHandle'])

    return [catalog[0] for catalog in catalog_list]

def _show_databases(self, catalog=None):
    """Return database names, optionally restricted to one catalog.

    Args:
        catalog: Catalog to list databases from. When falsy, a plain
            ``SHOW DATABASES`` is issued, so the scope is whatever the
            session currently uses.

    Returns:
        A list with the first field of every result row.
    """
    session = self._get_session()
    session_handle = session['id']
    statement = 'SHOW DATABASES IN `%(catalog)s`' % {'catalog': catalog} if catalog else 'SHOW DATABASES'
    operation_handle = self.db.execute_statement(session_handle=session_handle, statement=statement)
    db_list = self._check_status_and_fetch_result(session_handle, operation_handle['operationHandle'])

    return [db[0] for db in db_list]

Expand Down Expand Up @@ -461,12 +469,51 @@ def _get_columns(self, database, table):
]

def _show_functions(self, database):
    """List functions for autocomplete.

    When ``self.list_all_functions`` is truthy the listing spans every
    catalog and database (``database`` is then ignored); otherwise only
    the given database is queried.
    """
    if not self.list_all_functions:
        return self._show_functions_in_current_db(database)
    return self._show_all_functions()

def _show_functions_in_current_db(self, database):
    """Return ``[{'name': ...}, ...]`` for the functions of one database.

    Args:
        database: Database to inspect. When falsy, a plain ``SHOW FUNCTIONS``
            is issued without an ``IN`` clause.
    """
    session = self._get_session()
    if database:
        statement = 'SHOW FUNCTIONS IN `%(database)s`' % {'database': database}
    else:
        statement = 'SHOW FUNCTIONS'
    handle = self.db.execute_statement(session['id'], statement)
    rows = self._check_status_and_fetch_result(session['id'], handle['operationHandle'])
    # Only the first field of each row is used as the function name.
    return [{'name': row[0]} for row in rows]

def _show_all_functions(self):
    """List user functions from every catalog/database plus system functions.

    Flink UDFs can be registered in any database of any catalog, so every
    catalog and each of its databases is visited. User function names are
    returned fully qualified (<catalog_name>.<database_name>.<function_name>);
    the system functions are appended afterwards with their plain names.
    """
    qualified = []

    for catalog in self._show_catalogs():
        for database in self._show_databases(catalog):
            prefix = f'{catalog}.{database}.'
            for function in self._list_user_functions(catalog, database):
                qualified.append({'name': f'{prefix}{function["name"]}'})

    return qualified + self._list_system_functions()

def _list_system_functions(self):
    """Return the system (non-user) functions as ``[{'name': ...}, ...]``.

    Flink can list either all functions or only user functions, so the
    system functions are derived as "all minus user", both queried against
    the configured default catalog/database.
    """
    all_functions = self._list_functions(
        catalog=self.default_catalog, database=self.default_database, user_scope=False
    )
    user_functions = self._list_user_functions(catalog=self.default_catalog, database=self.default_database)
    # Compare by name through a set: avoids the quadratic dict-in-list scan
    # while keeping the original order of `all_functions`.
    user_names = {function['name'] for function in user_functions}
    return [function for function in all_functions if function['name'] not in user_names]

def _list_user_functions(self, catalog, database):
    """Return only the user-defined functions of ``catalog``.``database``."""
    return self._list_functions(catalog=catalog, database=database, user_scope=True)

def _list_functions(self, catalog, database, user_scope):
    """Run ``SHOW [USER] FUNCTIONS`` and return ``[{'name': ...}, ...]``.

    Args:
        catalog: Catalog qualifying ``database``; may be falsy, in which case
            the database is referenced without a catalog prefix.
        database: Database to inspect. When falsy, no ``IN`` clause is added
            and ``catalog`` is ignored.
        user_scope: When truthy only user functions are listed, otherwise
            all functions.
    """
    session = self._get_session()
    statement = 'SHOW USER FUNCTIONS' if user_scope else 'SHOW FUNCTIONS'
    if database:
        if catalog:
            scope = '`%(catalog)s`.`%(database)s`' % {'catalog': catalog, 'database': database}
        else:
            # No catalog configured: do not interpolate a literal `None` as a catalog name.
            scope = '`%(database)s`' % {'database': database}
        statement = statement + ' IN ' + scope
    operation_handle = self.db.execute_statement(session['id'], statement)
    function_list = self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle'])
    # Only the first field of each row is used as the function name.
    return [{'name': function[0]} for function in function_list]

def _show_function(self, function_name):
session = self._get_session()
if session.get('flink_version') and session['flink_version'].startswith('2.'):
Expand All @@ -476,9 +523,11 @@ def _show_function(self, function_name):
statement='DESCRIBE FUNCTION EXTENDED %(function_name)s' % {'function_name': function_name})
properties = dict(self._check_status_and_fetch_result(session['id'], operation_handle['operationHandle']))

# A function can be overloaded (multiple signatures), but only the first signature is returned.
signatures = properties.get('signature').split('\n')
return {
'name': function_name,
'signature': properties.get('signature'),
'signature': signatures[0],
}
else:
return {'name': function_name}
Expand Down
86 changes: 72 additions & 14 deletions desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,7 @@ def test_autocomplete_operation_functions(self, client_mock):
mock_client_instance.create_session.return_value = {'sessionHandle': self.TEST_SESSION_HANDLE}
mock_client_instance.info.return_value = {'version': '2.0.0'}
mock_client_instance.execute_statement.return_value = {'operationHandle': self.TEST_OPERATION_HANDLE}
mock_client_instance.fetch_results.return_value = {
'resultType': 'PAYLOAD',
'resultKind': 'SUCCESS_WITH_CONTENT',
'results': {
'columns': [{'name': 'function name', 'logicalType': {'type': 'VARCHAR', 'nullable': True, 'length': 1000}}],
'rowFormat': 'JSON',
'data': [
{'kind': 'INSERT', 'fields': ['lower']},
{'kind': 'INSERT', 'fields': ['upper']}
]},
'nextResultUri': f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON'
}
mock_client_instance.fetch_results.return_value = self._list_function_payload(['lower', 'upper'])

# and: FlinkSqlApi instance with configuration
flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter)
Expand All @@ -111,7 +100,54 @@ def test_autocomplete_operation_functions(self, client_mock):

# then
mock_client_instance.execute_statement.assert_called_once_with(self.TEST_SESSION_HANDLE, 'SHOW FUNCTIONS')
assert autocomplete_result == {'functions': [{'name': 'lower'}, {'name': 'upper'}]}
self._assert_autocomplete_functions(autocomplete_result, ['lower', 'upper'])

@patch('notebook.connectors.flink_sql.FlinkSqlClient')
def test_autocomplete_operation_functions_list_all(self, client_mock):
    # Verifies that with `list_all_functions` enabled the autocomplete returns
    # fully qualified user functions from every database plus system functions.

    # given: canned responses keyed by statement and by operation handle
    handle_by_statement = {
        'SHOW CATALOGS': 'show-catalogs',
        'SHOW DATABASES IN `test_catalog`': 'show-databases',
        'SHOW FUNCTIONS IN `test_catalog`.`db_a`': 'show-fns-dba',
        'SHOW USER FUNCTIONS IN `test_catalog`.`db_a`': 'show-user-fns-dba',
        'SHOW USER FUNCTIONS IN `test_catalog`.`db_b`': 'show-user-fns-dbb',
    }
    rows_by_handle = {
        'show-catalogs': ['test_catalog'],
        'show-databases': ['db_a', 'db_b'],
        'show-fns-dba': ['test_fun_a', 'lower', 'upper'],
        'show-user-fns-dba': ['test_fun_a'],
        'show-user-fns-dbb': ['test_fun_b'],
    }

    def mock_execute_statement(session_handle, statement):
        # Unknown statements yield None, exactly like a missing dict entry.
        handle = handle_by_statement.get(statement)
        if handle is None:
            return None
        return {'operationHandle': handle}

    def mock_fetch_results(session_handle, operation_handle, token):
        # A fresh payload is built on every call; unknown handles yield None.
        rows = rows_by_handle.get(operation_handle)
        if rows is None:
            return None
        return self._list_function_payload(rows)

    gateway = MagicMock()
    client_mock.return_value = gateway
    gateway.create_session.return_value = {'sessionHandle': self.TEST_SESSION_HANDLE}
    gateway.info.return_value = {'version': '2.0.0'}
    gateway.execute_statement.side_effect = mock_execute_statement
    gateway.fetch_results.side_effect = mock_fetch_results

    # and: FlinkSqlApi instance with configuration
    options = self.interpreter['options']
    options['list_all_functions'] = True
    options['default_catalog'] = 'test_catalog'
    options['default_database'] = 'db_a'
    flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter)

    # and: session is created
    flink_api.create_session(lang='flink', properties=None)

    # when
    autocomplete_result = flink_api.autocomplete(
        snippet='dummy', database=None, table=None, column=None, nested=None, operation='functions'
    )

    # then
    expected_names = ['lower', 'upper', 'test_catalog.db_a.test_fun_a', 'test_catalog.db_b.test_fun_b']
    self._assert_autocomplete_functions(autocomplete_result, expected_names)

@patch('notebook.connectors.flink_sql.FlinkSqlClient')
def test_autocomplete_operation_function_flink_1_x(self, client_mock):
Expand Down Expand Up @@ -164,7 +200,8 @@ def test_autocomplete_operation_function_flink_2_x(self, client_mock):
{'kind': 'INSERT', 'fields': ['signature', 'default_catalog.default_db.test_function(values <ANY>...)']},
]
},
'nextResultUri': f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON'
'nextResultUri':
f'/v3/sessions/{self.TEST_SESSION_HANDLE}/operations/{self.TEST_OPERATION_HANDLE}/result/1?rowFormat=JSON'
}

# and: FlinkSqlApi instance with configuration
Expand All @@ -185,3 +222,24 @@ def test_autocomplete_operation_function_flink_2_x(self, client_mock):
assert autocomplete_result == {
'function': {'name': 'test_function', 'signature': 'default_catalog.default_db.test_function(values <ANY>...)'}
}

def _list_function_payload(self, expected_functions, session_handle=None, operation_handle=None, token=0):
    # Builds a fetch-results style payload with one single-field INSERT row per
    # entry of `expected_functions`. Falsy handles fall back to the test class
    # defaults; `token` is accepted but not used.
    session_handle = session_handle or self.TEST_SESSION_HANDLE
    operation_handle = operation_handle or self.TEST_OPERATION_HANDLE

    name_column = {'name': 'function name', 'logicalType': {'type': 'VARCHAR', 'nullable': True, 'length': 1000}}
    rows = [{'kind': 'INSERT', 'fields': [fname]} for fname in expected_functions]
    next_uri = f'/v3/sessions/{session_handle}/operations/{operation_handle}/result/1?rowFormat=JSON'

    return {
        'resultType': 'PAYLOAD',
        'resultKind': 'SUCCESS_WITH_CONTENT',
        'results': {
            'columns': [name_column],
            'rowFormat': 'JSON',
            'data': rows,
        },
        'nextResultUri': next_uri,
    }

def _assert_autocomplete_functions(self, autocomplete_result, expected_fun_names):
    # Compares function names as sets, so order and duplicates are ignored.
    actual_fun_names = {entry['name'] for entry in autocomplete_result['functions']}
    assert set(expected_fun_names) == actual_fun_names