Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/benchmark/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def main(parameter_table: str, *, dataset="seer", repeats=1):
writer.writerow({"comment": "Not enough features"})
print("Skipping")
except DockerException:
print(f"Current run threw error, skipping")
print("Current run threw error, skipping")
writer.writerow({"records": records, "features": features,
"parties": parties, "comment": "error", })

Expand Down
1 change: 0 additions & 1 deletion python/scripts/benchmark_find_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def main():
rho = RHO
Rt = group_samples_at_risk(event_times)
Dt = group_samples_on_event_time(event_times, event_happened)
K = 1
deaths_per_t = Aggregator.compute_deaths_per_t(event_times, event_happened)
eps = EPSILON
relevant_event_times = Aggregator._group_relevant_event_times(event_times)
Expand Down
289 changes: 0 additions & 289 deletions python/scripts/visualize.py

This file was deleted.

29 changes: 16 additions & 13 deletions python/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@


@parser.value_converter
def strlist(l: str):
return l.split(",")
def str_list(some_string: str):
return some_string.split(",")


def compute_central_coefs(all_data_features, all_data_outcome):
def compute_central_coefficients(all_data_features, all_data_outcome):
central_model = CoxPHSurvivalAnalysis()
with warnings.catch_warnings():
warnings.filterwarnings('error', category=LinAlgWarning)
Expand All @@ -53,7 +53,7 @@ def compute_central_coefs(all_data_features, all_data_outcome):
return coef_dict


def compare_central_coefs(decentralized: dict[str, float], centralized: dict[str, float]) -> \
def compare_central_coefficients(decentralized: dict[str, float], centralized: dict[str, float]) -> \
dict[str, float]:
"""
Compares coefficients to results from centralized computation.
Expand Down Expand Up @@ -95,8 +95,8 @@ def compare_central_coefs(decentralized: dict[str, float], centralized: dict[str
class IntegrationTest(ABC):

def run(self, local_data, all_data, event_times_column, event_happened_column, *,
pythonnodes: strlist = ("pythonnode1:7777", "pythonnode2:7777"),
javanodes: strlist = ("javanode1:80", "javanode2:80", "javanode-outcome:80"),
pythonnodes: str_list = ("pythonnode1:7777", "pythonnode2:7777"),
javanodes: str_list = ("javanode1:80", "javanode2:80", "javanode-outcome:80"),
total_num_iterations=None):
"""
Run an integration test
Expand Down Expand Up @@ -131,11 +131,14 @@ def run(self, local_data, all_data, event_times_column, event_happened_column, *
end_time = datetime.now()
preparation_runtime = end_time - start_time
start_time = datetime.now()
results = self.run_integration_test(all_data_features, all_data_outcome, node_manager,
self.run_integration_test(all_data_features, all_data_outcome, node_manager,
check_correct)
end_time = datetime.now()
runtime = end_time - start_time

print(f"Finished test. Preparation took {preparation_runtime} seconds\n"
f"Convergence took {runtime} seconds")

@staticmethod
def run_integration_test(all_data_features, all_data_outcome, node_manager,
check_correct=True) -> (float, float):
Expand All @@ -157,24 +160,24 @@ def run_integration_test(all_data_features, all_data_outcome, node_manager,
)
try:
# Doing central one first to see if it succeeds
target_coefs = compute_central_coefs(all_data_features, all_data_outcome)
target_coefs = compute_central_coefficients(all_data_features, all_data_outcome)

node_manager.reset()
node_manager.fit()
coefs = node_manager.coefs
print(f"Betas: {coefs}")
print(f"Baseline hazard ratio {node_manager.baseline_hazard}")

comparison_metrics = compare_central_coefs(coefs, target_coefs)
comparison_metrics = compare_central_coefficients(coefs, target_coefs)
comparison_metrics["comment"] = "success"

print(f"Benchmark output: {json.dumps(comparison_metrics)}")

if check_correct:
for key, value in target_coefs.items():
np.testing.assert_almost_equal(value, coefs[key], decimal=DECIMAL_PRECISION)
print(f"Central and decentralized models are equal.")
except (LinAlgWarning, LinAlgError) as e:
print("Central and decentralized models are equal.")
except (LinAlgWarning, LinAlgError):
output = {"mse": None, "sad": None, "mad": None, "comment": "unsolvable"}
print(f"Benchmark output: {json.dumps(output)}")

Expand Down Expand Up @@ -223,8 +226,8 @@ def run_integration_test(all_data_features, all_data_outcome, node_manager,
central_c_index, _, _, _, _ = concordance_index_censored(event_indicator, event_time,
central_predictions)

target_coefs = compute_central_coefs(all_data_features_train, all_data_outcome_train)
comparison_metrics = compare_central_coefs(coefs, target_coefs)
target_coefs = compute_central_coefficients(all_data_features_train, all_data_outcome_train)
comparison_metrics = compare_central_coefficients(coefs, target_coefs)
comparison_metrics.update({"c_index_verticox": c_index, "c_index_central": central_c_index})
comparison_metrics["comment"] = "success"
print(f"Benchmark output: {json.dumps(comparison_metrics)}")
Expand Down
Loading