Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/examples/connection_examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ Create a connection that uses a specific gateway for secure access.
```
fab create .connections/conn.Connection -P gateway=MyVnetGateway.Gateway,connectionDetails.type=SQL,connectionDetails.parameters.server=<server>,connectionDetails.parameters.database=sales,credentialDetails.type=Basic,credentialDetails.username=<username>,credentialDetails.password=<password>
```
#### Create Connection with On-premises Gateway

Create a connection that uses a specific on-premises gateway with encrypted credentials for secure access

```
fab create .connections/conn.Connection -P gateway=MyVnetGateway.Gateway,connectionDetails.type=SQL,connectionDetails.parameters.server=<server>,connectionDetails.parameters.database=sales,credentialDetails.type=Basic,credentialDetails.values=[{"gatewayId":"<gatewayId>", "encryptedCredentials": "<encryptedCredentials>"}]
```

#### Create Connection with All Parameters

Expand Down
90 changes: 74 additions & 16 deletions src/fabric_cli/utils/fab_cmd_mkdir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,33 +359,41 @@ def check_required_params(params, required_params):
)


def _validate_credential_params(cred_type, provided_cred_params):
ignored_params = []
params = {}
def _get_params_per_cred_type(cred_type, is_on_premises_gateway):
match cred_type:
case "Anonymous" | "WindowsWithoutImpersonation" | "WorkspaceIdentity":
param_keys = []
return []
case "Basic" | "Windows":
param_keys = ["username", "password"]
if is_on_premises_gateway:
return ["values"]
else:
return ["username", "password"]
case "Key":
param_keys = ["key"]
return ["key"]
case "OAuth2":
raise FabricCLIError(
"OAuth2 credential type is not supported",
fab_constant.ERROR_NOT_SUPPORTED,
)
case "ServicePrincipal":
param_keys = [
return [
"servicePrincipalClientId",
"servicePrincipalSecret",
"tenantId",
]
case "SharedAccessSignature":
param_keys = ["token"]
return ["token"]
case _:
utils_ui.print_warning(
f"Unsupported credential type {cred_type}. Skipping validation"
)
return []


def _validate_credential_params(cred_type, provided_cred_params, is_on_premises_gateway):
ignored_params = []
params = {}
param_keys = _get_params_per_cred_type(cred_type, is_on_premises_gateway)

missing_params = [
key for key in param_keys if key.lower() not in provided_cred_params
Expand All @@ -405,12 +413,53 @@ def _validate_credential_params(cred_type, provided_cred_params):
utils_ui.print_warning(
f"Ignoring unsupported parameters for credential type {cred_type}: {ignored_params}"
)
if is_on_premises_gateway:
provided_cred_params["values"] = _validate_and_get_on_premises_gateway_credential_values(provided_cred_params.get("values"))

for key in param_keys:
params[key] = provided_cred_params[key.lower()]

return params

def _validate_and_get_on_premises_gateway_credential_values(cred_values):
# Validate that the provided credential values are in the correct format
# The values should be a list of JSON objects with the following keys:
# - gatewayId: The ID of the OnPremisesGateway
# - encryptedCredentials: The encrypted credentials for the gateway
# The values should be a list of JSON objects
# Validate all items are JSON objects first (with early break)
for item in cred_values:
if not isinstance(item, dict):
raise FabricCLIError(
ErrorMessages.Common.invalid_json_format(),
fab_constant.ERROR_INVALID_INPUT,
)

param_values_keys = ["gatewayId", "encryptedCredentials"]
missing_params = [
key for key in param_values_keys
if not all(key.lower() in {k.lower() for k in item.keys()} for item in cred_values)
]
if len(missing_params) > 0:
raise FabricCLIError(
f"Missing parameters for credential values in OnPremesisGateway connectivity type: {missing_params}",
fab_constant.ERROR_INVALID_INPUT,
)

ignored_params = [
key
for item in cred_values
for key in item.keys()
if key not in [k.lower() for k in param_values_keys]
]
if len(ignored_params) > 0:
utils_ui.print_warning(
f"Ignoring unsupported parameters for on-premises gateway: {ignored_params}"
)

return [{key: item[key.lower()] for key in param_values_keys if key.lower() in item} for item in cred_values]



def get_connection_config_from_params(payload, con_type, con_type_def, params):
connection_request = payload
Expand Down Expand Up @@ -537,13 +586,6 @@ def get_connection_config_from_params(payload, con_type, con_type_def, params):
fab_constant.ERROR_INVALID_INPUT,
)

if missing_params:
missing_params_str = ", ".join(missing_params)
raise FabricCLIError(
f"Missing parameter(s) {missing_params_str} for creation method {c_method}",
fab_constant.ERROR_INVALID_INPUT,
)

connection_request["connectionDetails"] = {
"type": con_type,
"creationMethod": creation_method["name"],
Expand All @@ -563,6 +605,17 @@ def get_connection_config_from_params(payload, con_type, con_type_def, params):
"password": "********"
}
}
or in case of OnPremisesGateway:
"credentialDetails": {
"credentialType": "Basic",
"singleSignOnType": "None",
"connectionEncryption": "NotEncrypted",
"skipTestConnection": false,
"credentials": {
"credentialType": "Basic",
"values": [{gatewayId: "gatewayId", encryptedCredentials: "**********"}]
}
}
"""
sup_cred_types = ", ".join(con_type_def["supportedCredentialTypes"])
if not params.get("credentialdetails"):
Expand Down Expand Up @@ -603,7 +656,8 @@ def get_connection_config_from_params(payload, con_type, con_type_def, params):
if "skiptestconnection" in provided_cred_params:
provided_cred_params.pop("skiptestconnection")

connection_params = _validate_credential_params(cred_type, provided_cred_params)
is_on_premises_gateway = connection_request.get("connectivityType").lower() == "onpremisesgateway"
connection_params = _validate_credential_params(cred_type, provided_cred_params, is_on_premises_gateway)

connection_request["credentialDetails"] = {
"singleSignOnType": singleSignOnType,
Expand All @@ -613,6 +667,10 @@ def get_connection_config_from_params(payload, con_type, con_type_def, params):
}
connection_request["credentialDetails"]["credentials"]["credentialType"] = cred_type

# Build credential details based on connectivity type
if is_on_premises_gateway:
connection_request["credentialDetails"]["credentials"]["values"] = connection_params.get("values")

return connection_request


Expand Down
20 changes: 19 additions & 1 deletion src/fabric_cli/utils/fab_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,27 @@ def get_dict_from_parameter(
key, rest = param.split(".", 1)
return {key: get_dict_from_parameter(rest, value, max_depth, current_depth + 1)}
else:
clean_value = value.replace("'", "").replace('"', "")
clean_value = try_get_json_value_from_string(value)
return {param: clean_value}

def try_get_json_value_from_string(value: str) -> Any:
"""
Try to parse a string as JSON, with special handling for array parameters.

Args:
value: String that may contain JSON data

Returns:
Parsed JSON if valid, otherwise original string
"""
try:
parse = json.loads(value)
if (isinstance(parse, list) and all(isinstance(item, dict) for item in parse)):
return parse
except json.JSONDecodeError:
# For non-JSON values, return as-is without quote stripping
pass
return value.replace("'", "").replace('"', "")

def merge_dicts(dict1: dict, dict2: dict) -> dict:
"""
Expand Down
72 changes: 71 additions & 1 deletion tests/test_commands/api_processors/connection_api_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

import json
from tests.test_commands.api_processors.base_api_processor import BaseAPIProcessor
from tests.test_commands.api_processors.utils import load_response_json_body
from tests.test_commands.api_processors.utils import (
load_request_json_body,
load_response_json_body,
)
from tests.test_commands.data.static_test_data import get_mock_data, get_static_data


class ConnectionAPIProcessor(BaseAPIProcessor):
Expand All @@ -13,6 +17,22 @@ def __init__(self, generated_name_mapping):
self.generated_name_mapping = generated_name_mapping

def try_process_request(self, request) -> bool:
uri = request.uri
# First, handle URI mocking for gateway IDs
self._mock_gateway_id_in_uri(request)

# Handle connection creation and listing
if uri.lower() == self.CONNECTIONS_URI.lower():
method = request.method
if method == "POST":
"""https://learn.microsoft.com/en-us/rest/api/fabric/core/connections/create-connection?tabs=HTTP"""
self._handle_post_request(request)
return True

# Handle supported connection types with gateway ID query parameter
if uri.lower().startswith(f"{self.CONNECTIONS_URI.lower()}/supportedconnectiontypes"):
return True

return False

def try_process_response(self, request, response) -> bool:
Expand All @@ -23,6 +43,9 @@ def try_process_response(self, request, response) -> bool:
if method == "GET":
"""https://learn.microsoft.com/en-us/rest/api/fabric/core/connections/list-connections?tabs=HTTP"""
self._handle_get_response(response)
if method == "POST":
"""https://learn.microsoft.com/en-us/rest/api/fabric/core/connections/create-connection?tabs=HTTP"""
self._handle_post_response(response)
return True
return False

Expand All @@ -34,9 +57,56 @@ def _handle_get_response(self, response):
new_value = []
for item in data["value"]:
if item.get("displayName") in self.generated_name_mapping:
self._mock_gateway_references(item)
new_value.append(item)

data["value"] = new_value

new_body_str = json.dumps(data)
response["body"]["string"] = new_body_str.encode("utf-8")

def _handle_post_request(self, request):
"""Handle POST request for connection creation"""
data = load_request_json_body(request)
if not data:
return

self._mock_gateway_references(data)

new_body_str = json.dumps(data)
request.body = new_body_str

def _handle_post_response(self, response):
"""Handle POST response for connection creation"""
data = load_response_json_body(response)
if not data:
return

self._mock_gateway_references(data)

new_body_str = json.dumps(data)
response["body"]["string"] = new_body_str.encode("utf-8")

def _mock_gateway_references(self, obj):
"""Mock gateway ID references in connection objects"""
static_gateway_id = get_static_data().onpremises_gateway_details.id
mock_gateway_id = get_mock_data().onpremises_gateway_details.id

# Mock direct gatewayId field
if "gatewayId" in obj and obj["gatewayId"] == static_gateway_id:
obj["gatewayId"] = f"{mock_gateway_id}"

# Mock gatewayId in credentialDetails.values arrays
if "credentialDetails" in obj and "values" in obj["credentialDetails"]:
for cred_value in obj["credentialDetails"]["values"]:
if isinstance(cred_value, dict) and "gatewayId" in cred_value:
if cred_value["gatewayId"] == static_gateway_id:
cred_value["gatewayId"] = mock_gateway_id

def _mock_gateway_id_in_uri(self, request):
"""Mock gateway IDs in request URIs and query parameters"""
static_gateway_id = get_static_data().onpremises_gateway_details.id
mock_gateway_id = get_mock_data().onpremises_gateway_details.id

# Replace gateway ID in URI path and query parameters
request.uri = request.uri.replace(static_gateway_id, mock_gateway_id)
Loading