
Restricted cubic spline (RCS) visualization issue in lifelines #1664

@kcw0331

Hello,
I’m encountering a problem implementing and visualizing restricted cubic splines (RCS) with the lifelines library in Python. My goal is to reproduce the RCS visualization provided by the R package rms, but I haven’t been able to get a correct RCS plot.
I’ve included my current lifelines code below. When you have time, could you review it and suggest what I might be missing or how to improve it?
Thank you in advance!

cancer.csv

---------- 0) Package load ----------

import numpy as np
import pandas as pd
from lifelines import CoxPHFitter
import matplotlib.pyplot as plt

temp2 = pd.read_csv("./cancer.csv")  # attached dataset; read after pandas is imported

---------- 1) RCS basis function (as is) ----------

def rcs_basis(x, knots, include_intercept=True):
    x = np.asarray(x, dtype=float)
    K = len(knots)
    if K < 3:
        raise ValueError("at least 3 knots are required")
    min_knot, max_knot = knots[0], knots[-1]

    def basis_term(x, knot):
        lam = (max_knot - knot) / (max_knot - min_knot)
        return (
            np.maximum(x - knot, 0) ** 3
            - lam * np.maximum(x - min_knot, 0) ** 3
            - (1 - lam) * np.maximum(x - max_knot, 0) ** 3
        )

    # Despite its name, include_intercept toggles the linear term 'x'; a Cox model has no intercept
    basis = [x] if include_intercept else []
    for k in knots[1:-1]:  # one nonlinear term per interior knot
        basis.append(basis_term(x, k))
    return np.column_stack(basis)
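
Note on comparing with rms: the nonlinear terms above are anchored at the first and last knots, whereas the truncated-power form usually given for restricted cubic splines (Harrell's, which as far as I know is what rms builds on, with an extra normalization applied by default) is anchored at the last two knots. Both should span the same space of natural cubic splines, so individual coefficients are not expected to match between the two packages even when the fitted curves agree. A minimal sketch to check that numerically; `basis_post`, `basis_harrell`, `max_projection_residual` and the knot values are illustrative names and made-up numbers, not part of lifelines or rms:

```python
import numpy as np

def pos3(u):
    """Truncated cubic (u)_+^3."""
    return np.maximum(u, 0.0) ** 3

def basis_post(x, knots):
    """Linear term plus the nonlinear terms as written in the post."""
    x = np.asarray(x, dtype=float)
    t1, tk = knots[0], knots[-1]
    cols = []
    for t in knots[1:-1]:  # one term per interior knot
        lam = (tk - t) / (tk - t1)
        cols.append(pos3(x - t) - lam * pos3(x - t1) - (1 - lam) * pos3(x - tk))
    return np.column_stack([x] + cols)

def basis_harrell(x, knots):
    """Linear term plus Harrell-style terms (unnormalized), anchored at the last two knots."""
    x = np.asarray(x, dtype=float)
    tkm1, tk = knots[-2], knots[-1]
    cols = []
    for t in knots[:-2]:  # one term per knot except the last two
        cols.append(
            pos3(x - t)
            - pos3(x - tkm1) * (tk - t) / (tk - tkm1)
            + pos3(x - tk) * (tkm1 - t) / (tk - tkm1)
        )
    return np.column_stack([x] + cols)

def max_projection_residual(A, B):
    """Largest absolute residual after regressing the columns of B on [1, A]."""
    A1 = np.column_stack([np.ones(len(A)), A])
    coef, *_ = np.linalg.lstsq(A1, B, rcond=None)
    return float(np.max(np.abs(B - A1 @ coef)))

knots = np.array([40.0, 55.0, 62.0, 70.0])  # made-up knots, for illustration only
x = np.linspace(30.0, 80.0, 400)
resid = max_projection_residual(basis_post(x, knots), basis_harrell(x, knots))
print(resid)
```

If `resid` is tiny relative to the size of the basis values, the two parametrizations describe the same family of curves, and any remaining difference in the plot has to come from somewhere else (knot locations, reference value, covariate adjustment, or the y-axis scale).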

---------- 2) Build RCS ----------

hr_col = 'age'
# 3 knots at the 10th/50th/90th percentiles; alternatives: 4 knots [5, 35, 65, 95], 5 knots [5, 27.5, 50, 72.5, 95] (Harrell's suggested quantiles)
knots = np.percentile(temp2[hr_col].dropna(), [10, 50, 90])

X_rcs = rcs_basis(temp2[hr_col], knots=knots, include_intercept=True)
rcs_cols = [f"hr_rcs_{i}" for i in range(X_rcs.shape[1])]
df_rcs = pd.DataFrame(X_rcs, columns=rcs_cols, index=temp2.index)

---------- 3) Prepare covariates (★ key added part) ----------

# Specify numeric/categorical covariates
num_covs = ["size"]                # numeric covariates
cat_covs = ["sex", "metastasis"]   # categorical covariates (e.g., sex=M/F, metastasis=0/1)

# Impute missing values (simple: numeric=median, categorical=mode)
cov_raw = temp2[num_covs + cat_covs].copy()
for c in num_covs:
    cov_raw[c] = cov_raw[c].fillna(cov_raw[c].median())
for c in cat_covs:
    cov_raw[c] = cov_raw[c].fillna(cov_raw[c].mode().iloc[0])

# Dummy-encode (drop_first=True to avoid multicollinearity; dtype=float so the dummies are numeric rather than bool)
cov_df = pd.get_dummies(cov_raw, columns=cat_covs, drop_first=True, dtype=float)

---------- 4) Build training df_model ----------

df_model = pd.concat([df_rcs, cov_df], axis=1)
df_model["time"]  = temp2["time"].to_numpy()
df_model["event"] = temp2["event"].to_numpy()

# lifelines errors on missing values → final safeguard
df_model = df_model.dropna(axis=0)

---------- 5) Fit ----------

cph = CoxPHFitter()
cph.fit(df_model, duration_col="time", event_col="event")
cph.print_summary()

---------- 6) Build prediction matrix for visualization (★ also key) ----------

# Vary only the spline variable (age here); fix the other covariates at "typical" values to draw the curve
hr_range = np.linspace(np.percentile(temp2[hr_col].dropna(), 1),
                       np.percentile(temp2[hr_col].dropna(), 99), 600)

X_vis_rcs = rcs_basis(hr_range, knots=knots, include_intercept=True)
df_hr_vis = pd.DataFrame(X_vis_rcs, columns=rcs_cols)

# Reference covariates (numeric=median; dummies=column median ≈ modal category)
cov_ref = cov_df.median(numeric_only=True)
df_cov_vis = pd.DataFrame(np.tile(cov_ref.values, (len(hr_range), 1)),
                          columns=cov_df.columns)

# Concatenate in the same column order as training
X_vis_full = pd.concat([df_hr_vis, df_cov_vis], axis=1)

---------- 7) Predict partial hazard & normalize by reference HR (readability) ----------

partial_hazard = cph.predict_partial_hazard(X_vis_full).values.reshape(-1)

# Normalize so the curve equals 1 at the reference value of the spline variable (here its median); optional
hr_ref = float(np.median(temp2[hr_col].dropna()))
X_ref_rcs = rcs_basis([hr_ref], knots=knots, include_intercept=True)
df_ref_hr = pd.DataFrame(X_ref_rcs, columns=rcs_cols)
df_ref_cov = pd.DataFrame([cov_ref.values], columns=cov_df.columns)
X_ref_full = pd.concat([df_ref_hr, df_ref_cov], axis=1)
# Single-row prediction; take the first element explicitly instead of calling float() on an array
ref_hazard = float(np.ravel(cph.predict_partial_hazard(X_ref_full))[0])
partial_hazard_norm = partial_hazard / ref_hazard

---------- 8) Plot ----------

plt.plot(hr_range, partial_hazard_norm)
plt.axhline(1.0, ls='--', lw=1)
plt.xlabel("Age")
plt.ylabel("Relative hazard (vs. median age)")
plt.title("Age (RCS), adjusted for covariates")
plt.grid(True)
plt.tight_layout()
plt.show()
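
The plot above has no confidence band, which rms-style plots normally show. Because the covariates are held fixed, they cancel out of the ratio to the reference, so the log relative hazard is just the difference of the spline rows times the spline coefficients, and a pointwise Wald band follows from the coefficient covariance. A minimal sketch under that assumption; `relative_hazard_band` is a hypothetical helper, not a lifelines function:

```python
import numpy as np

def relative_hazard_band(B_grid, b_ref, beta, cov, z=1.96):
    """Hazard ratio vs. a reference point, with a pointwise Wald band.

    B_grid : (n, p) spline basis evaluated on the plotting grid
    b_ref  : (p,)   spline basis evaluated at the reference value
    beta   : (p,)   fitted coefficients of the spline columns
    cov    : (p, p) covariance matrix of those coefficients
    """
    # Contrast of each grid row against the reference row
    D = np.asarray(B_grid, dtype=float) - np.asarray(b_ref, dtype=float)
    beta = np.asarray(beta, dtype=float)
    cov = np.asarray(cov, dtype=float)
    log_hr = D @ beta
    # Var(d' beta) = d' cov d, evaluated row by row
    se = np.sqrt(np.einsum("ij,jk,ik->i", D, cov, D))
    return np.exp(log_hr), np.exp(log_hr - z * se), np.exp(log_hr + z * se)
```

With the fit above, the inputs would be `B_grid = X_vis_rcs`, `b_ref = X_ref_rcs[0]`, `beta = cph.params_[rcs_cols]` and `cov = cph.variance_matrix_.loc[rcs_cols, rcs_cols]`. This assumes `variance_matrix_` is a DataFrame indexed by covariate name, which should be checked against the installed lifelines version. The three returned arrays can then be drawn with `plt.plot` and `plt.fill_between`, ideally on a log y-axis.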
