diff --git a/cvmassistants/secretprovider/secret-provider-agent/src/secret_provider_agent.c b/cvmassistants/secretprovider/secret-provider-agent/src/secret_provider_agent.c index df6f405..b847538 100644 --- a/cvmassistants/secretprovider/secret-provider-agent/src/secret_provider_agent.c +++ b/cvmassistants/secretprovider/secret-provider-agent/src/secret_provider_agent.c @@ -51,7 +51,7 @@ char* get_secret_from_sbs_through_rats_tls(rats_tls_log_level_t log_level, bool mutual, char* ip, int port, - bool appid_flag) { + char* app_id) { bool validation_error = false; if (attester_type == NULL || strlen(attester_type) >= ENCLAVE_ATTESTER_TYPE_NAME_SIZE) { @@ -89,20 +89,13 @@ char* get_secret_from_sbs_through_rats_tls(rats_tls_log_level_t log_level, rats_tls_conf_t conf; memset(&conf, 0, sizeof(conf)); - char* app_id; claim_t custom_claims[1]; - if (appid_flag) { - app_id = getenv("appId"); - if (NULL != app_id) { - custom_claims[0].name = "appId"; - custom_claims[0].value = (uint8_t*)app_id; - custom_claims[0].value_size = strlen(app_id); - conf.custom_claims = (claim_t*)custom_claims; - conf.custom_claims_length = 1; - } else { - LOG_ERROR("Could not read the app_id from env"); - return NULL; - } + if (app_id != NULL) { + custom_claims[0].name = "appId"; + custom_claims[0].value = (uint8_t*)app_id; + custom_claims[0].value_size = strlen(app_id); + conf.custom_claims = (claim_t*)custom_claims; + conf.custom_claims_length = 1; } conf.log_level = log_level; @@ -229,27 +222,14 @@ int main(int argc, char** argv) { setvbuf(stdout, NULL, _IONBF, 0); char* secret = ""; LOG_INFO("Try to get key from SBS"); - char* sbs_endpoint = getenv("sbsEndpoint"); - if (NULL == sbs_endpoint) { - LOG_ERROR("SBS mode must config sbsEndpoint environment variable"); - return -1; - } - - LOG_DEBUG("Config of SBS endpoint is %s", sbs_endpoint); char* secret_save_path = NULL; + char* sbs_endpoint = NULL; char* srv_ip = NULL; char* str_port = NULL; int port; - srv_ip = strtok(sbs_endpoint, ":"); - str_port = strtok(NULL, ":"); - if (NULL == str_port) { - LOG_ERROR("sbsEndpoint format error, eg: 127.0.0.1:5443"); - return -1; - } - port = atoi(str_port); - char* const short_options = "a:v:t:c:ml:fs:h"; + char* const short_options = "a:v:t:c:ml:s:i:e:h"; struct option long_options[] = { {"attester", required_argument, NULL, 'a'}, {"verifier", required_argument, NULL, 'v'}, @@ -257,8 +237,9 @@ int main(int argc, char** argv) { {"crypto", required_argument, NULL, 'c'}, {"mutual", no_argument, NULL, 'm'}, {"log-level", required_argument, NULL, 'l'}, - {"appId", no_argument, NULL, 'f'}, {"savePath", required_argument, NULL, 's'}, + {"appId", required_argument, NULL, 'i'}, + {"sbsEndpoint", required_argument, NULL, 'e'}, {"help", no_argument, NULL, 'h'}, {0, 0, 0, 0}}; @@ -267,7 +248,7 @@ int main(int argc, char** argv) { char* tls_type = ""; char* crypto_type = ""; bool mutual = true; - bool appid_flag = false; + char* app_id = NULL; int opt; do { opt = getopt_long(argc, argv, short_options, long_options, NULL); @@ -298,12 +279,15 @@ int main(int argc, char** argv) { else if (!strcasecmp(optarg, "off")) log_level = RATS_TLS_LOG_LEVEL_NONE; break; - case 'f': - appid_flag = true; + case 'i': + app_id = optarg; break; case 's': secret_save_path = optarg; break; + case 'e': + sbs_endpoint = optarg; + break; case -1: break; case 'h': @@ -321,8 +305,9 @@ int main(int argc, char** argv) { " --port/-p set the listening tcp port\n" " --debug-enclave/-D set to enable enclave debugging\n" " --verdictd/-E set to connect verdictd based on EAA protocol\n" - " --appId/-f need to add appid to claims\n" - " --savePath/-s save secret to local path" + " --appId/-i value set the appId value to add to claims\n" + " --savePath/-s save secret to local path\n" + " --sbsEndpoint/-e set the SBS endpoint (format: IP:PORT)\n" " --help/-h show the usage\n"); exit(-1); default: @@ -332,6 +317,25 @@ int main(int argc, char** argv) { LOG_INFO("Selected log level %d", log_level); + if (sbs_endpoint == NULL) { + LOG_ERROR("SBS mode must provide sbsEndpoint argument (--sbsEndpoint/-e)"); + return -1; + } + + LOG_DEBUG("Config of SBS endpoint is %s", sbs_endpoint); + + srv_ip = strtok(sbs_endpoint, ":"); + str_port = strtok(NULL, ":"); + if (NULL == str_port) { + LOG_ERROR("sbsEndpoint format error, eg: 127.0.0.1:5443"); + return -1; + } + port = atoi(str_port); + if (port == 0) { + LOG_ERROR("Port is invalid, got %s", str_port); + return -1; + } + if (secret_save_path == NULL) { LOG_ERROR("Path to store secret locally is missing"); return -1; @@ -344,7 +348,7 @@ int main(int argc, char** argv) { secret = get_secret_from_sbs_through_rats_tls(log_level, attester_type, verifier_type, tls_type, crypto_type, mutual, srv_ip, - port, appid_flag); + port, app_id); if (secret == NULL) { LOG_ERROR("Get secret from SBS failed"); return -1;