-
-
Notifications
You must be signed in to change notification settings - Fork 568
Description
Hello,
I’m encountering a problem implementing and visualizing restricted cubic splines (RCS) with the lifelines library in Python. My goal is to reproduce the RCS visualization provided by the R package rms, but I haven’t been able to get a correct RCS plot.
I’ve included my current lifelines code below. When you have time, could you review it and suggest what I might be missing or how to improve it?
Thank you in advance!
# ---------- 0) Package load ----------
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

# Load the data only after pandas has been imported: calling pd.read_csv
# above the import block raises NameError when the script runs top to bottom.
temp2 = pd.read_csv("./cancer.csv")
# ---------- 1) RCS basis function (as is) ----------
def rcs_basis(x, knots, include_intercept=True, normalize=False):
    """Build a restricted (natural) cubic spline design matrix for ``x``.

    One non-linear column is produced per interior knot ``t``::

        (x - t)+^3 - lam * (x - t_min)+^3 - (1 - lam) * (x - t_max)+^3
        lam = (t_max - t) / (t_max - t_min)

    Each column is zero below the first knot and linear above the last one.

    Parameters
    ----------
    x : array-like
        Values of the covariate. Converted to a float array; NaNs propagate
        to the corresponding rows of the result.
    knots : array-like
        At least three knot locations. They are sorted internally, so the
        order in which they are passed does not matter.
    include_intercept : bool, default True
        Despite the name, this controls whether the *linear* term ``x`` is
        included as the first column (no constant column is ever added).
    normalize : bool, default False
        If True, divide the non-linear columns by ``(t_max - t_min) ** 2`` so
        they are on the same scale as ``x``. This rescales the fitted
        coefficients but not the fitted curve, and improves numerical
        conditioning of the design matrix.

    Returns
    -------
    numpy.ndarray
        2-D array with ``len(x)`` rows and ``len(knots) - 2`` non-linear
        columns, preceded by ``x`` itself when ``include_intercept`` is True.

    Raises
    ------
    ValueError
        If fewer than three knots are given, or if the boundary knots are
        non-finite or coincide (which would make ``lam`` undefined).
    """
    # Float conversion keeps the cubes in floating point even for int input.
    x = np.asarray(x, dtype=float)
    # Sorting guarantees knots[0] / knots[-1] really are the boundary knots.
    knots = np.sort(np.asarray(knots, dtype=float).ravel())
    K = len(knots)
    if K < 3:
        raise ValueError("least 3 knot need")
    min_knot, max_knot = knots[0], knots[-1]
    span = max_knot - min_knot
    # NaN knots sort to the end, so a NaN span is caught here as well.
    if not np.isfinite(span) or span <= 0:
        raise ValueError("knots must be finite and span a non-zero range")
    scale = span ** 2 if normalize else 1.0

    def basis_term(values, knot):
        # lam weights the two boundary terms so that the cubic and quadratic
        # parts cancel beyond the last knot (linear tail restriction).
        lam = (max_knot - knot) / span
        return (
            np.maximum(values - knot, 0) ** 3
            - lam * np.maximum(values - min_knot, 0) ** 3
            - (1 - lam) * np.maximum(values - max_knot, 0) ** 3
        ) / scale

    # Cox models have no intercept; only the linear term 'x' is optional here.
    basis = [x] if include_intercept else []
    # One non-linear column per interior knot.
    basis.extend(basis_term(x, k) for k in knots[1:-1])
    return np.column_stack(basis)
# ---------- 2) Build RCS ----------
# Covariate that enters the model through the spline.
hr_col = "age"
spline_source = temp2[hr_col]

# Knots at the 10th/50th/90th percentiles of the observed (non-missing) values.
# Alternative placements: [5, 25, 50, 75, 95] or [5, 35, 65, 95].
knots = np.percentile(spline_source.dropna(), [10, 50, 90])

# Spline design matrix: linear term first, then one column per interior knot.
X_rcs = rcs_basis(spline_source, knots=knots, include_intercept=True)
n_basis = X_rcs.shape[1]
rcs_cols = [f"hr_rcs_{i}" for i in range(n_basis)]

# Keep the original row index so the spline columns align with the covariates.
df_rcs = pd.DataFrame(data=X_rcs, index=temp2.index, columns=rcs_cols)
# ---------- 3) Prepare covariates (★ key added part) ----------
# Adjustment covariates, split by how they are imputed and encoded.
num_covs = ["size"]  # numeric
cat_covs = ["sex", "metastasis"]  # categorical

cov_raw = temp2[num_covs + cat_covs].copy()

# Simple imputation: column median for numeric covariates, most frequent
# value (first mode) for categorical ones.
fill_values = {name: cov_raw[name].median() for name in num_covs}
fill_values.update({name: cov_raw[name].mode().iloc[0] for name in cat_covs})
for name, filler in fill_values.items():
    cov_raw[name] = cov_raw[name].fillna(filler)

# One-hot encode the categorical covariates; the first level of each is
# dropped so the dummy columns are not perfectly collinear.
cov_df = pd.get_dummies(cov_raw, columns=cat_covs, drop_first=True)
# ---------- 4) Build training df_model ----------
# Training frame: spline columns + encoded covariates + outcome columns.
# The outcomes are attached positionally (via to_numpy) rather than by index,
# and any row that still has a missing value is dropped because lifelines
# raises on NaNs.
df_model = (
    pd.concat([df_rcs, cov_df], axis=1)
    .assign(
        time=temp2["time"].to_numpy(),
        event=temp2["event"].to_numpy(),
    )
    .dropna(axis=0)
)
# ---------- 5) Fit ----------
# Fit the Cox proportional hazards model on every column of df_model other
# than the duration/event columns, then print the coefficient table.
cph = CoxPHFitter().fit(df_model, duration_col="time", event_col="event")
cph.print_summary()
# ---------- 6) Build prediction matrix for visualization (★ also key) ----------
# Only the spline variable varies along the curve; every other covariate is
# held at a single "typical" value.
grid_lo, grid_hi = np.percentile(temp2[hr_col].dropna(), [1, 99])
hr_range = np.linspace(grid_lo, grid_hi, 600)

# Spline columns evaluated on the grid, using the same knots as in training.
X_vis_rcs = rcs_basis(hr_range, knots=knots, include_intercept=True)
df_hr_vis = pd.DataFrame(X_vis_rcs, columns=rcs_cols)

# Reference covariate row: the column-wise median of the encoded covariates
# (for a 0/1 dummy this is usually the more frequent level).
cov_ref = cov_df.median(numeric_only=True)
n_grid = len(hr_range)
repeated_ref = np.repeat(cov_ref.values[np.newaxis, :], n_grid, axis=0)
df_cov_vis = pd.DataFrame(repeated_ref, columns=cov_df.columns)

# Spline columns first, then covariates: the same column order as df_model.
X_vis_full = pd.concat([df_hr_vis, df_cov_vis], axis=1)
# ---------- 7) Predict partial hazard & normalize by reference HR (readability) ----------
# Partial hazard along the grid, flattened to 1-D. np.asarray accepts both a
# Series and a single-column DataFrame return value.
partial_hazard = np.asarray(cph.predict_partial_hazard(X_vis_full)).ravel()

# Reference point: the median of the spline variable, with the covariates at
# the same reference values used for the grid. Dividing by its partial hazard
# makes the curve equal 1 at the reference (optional, for readability).
hr_ref = float(np.median(temp2[hr_col].dropna()))
X_ref_rcs = rcs_basis([hr_ref], knots=knots, include_intercept=True)
df_ref_hr = pd.DataFrame(X_ref_rcs, columns=rcs_cols)
df_ref_cov = pd.DataFrame([cov_ref.values], columns=cov_df.columns)
X_ref_full = pd.concat([df_ref_hr, df_ref_cov], axis=1)

# Pull the single prediction out explicitly: float() on a 1-element array
# (ndim > 0) is deprecated since NumPy 1.25.
ref_hazard = float(np.asarray(cph.predict_partial_hazard(X_ref_full)).ravel()[0])
partial_hazard_norm = partial_hazard / ref_hazard
# ---------- 8) Plot ----------
# Relative hazard against the spline variable, with a dashed line at 1.0
# marking the reference level.
plt.plot(hr_range, partial_hazard_norm)
plt.axhline(1.0, ls="--", lw=1)
# Labels are derived from hr_col so they always name the variable that is
# actually plotted (the hard-coded "Heart Rate" labels did not match 'age').
plt.xlabel(hr_col.capitalize())
plt.ylabel(f"Relative Hazard (vs. median {hr_col})")
plt.title(f"{hr_col.capitalize()} (RCS) with Covariates")
plt.grid(True)
plt.tight_layout()
plt.show()