Skip to content

Commit 3182099

Browse files
committed
v6ver481v3_rm-dup-rowv2
1 parent fd4cc86 commit 3182099

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

federated_cvdm_training_poc/partial_risk_prediction.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ def log_vantage_dists(info_fn):
5050

5151
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5252

53-
def drop_duplicated_header_rows(df: pd.DataFrame) -> pd.DataFrame:
53+
import pandas as pd
54+
55+
def drop_duplicated_header_rows_strict(df: pd.DataFrame) -> pd.DataFrame:
5456
"""
55-
Drop rows that are duplicated CSV header lines inside the file.
56-
Typical symptom: a row contains literal strings equal to column names
57-
(e.g., 'GENDER' appears as a value under column GENDER).
57+
Strictly drop only rows that are duplicated CSV headers.
58+
A duplicated header row typically has values equal to the column names.
59+
This version is designed to avoid removing any legitimate data rows.
5860
"""
5961
if df is None or df.empty:
6062
return df
@@ -63,21 +65,27 @@ def drop_duplicated_header_rows(df: pd.DataFrame) -> pd.DataFrame:
6365
if not cols:
6466
return df
6567

66-
# Count how many columns match their own column-name in each row
67-
hits = np.zeros(len(df), dtype=int)
68-
for c in cols:
69-
hits += df[c].astype(str).str.strip().eq(str(c).strip()).to_numpy(dtype=int)
68+
# Compare each cell to its column name (string-wise)
69+
colname_row = pd.Series({c: str(c).strip() for c in cols})
70+
mask_full = df[cols].astype(str).apply(lambda s: s.str.strip()).eq(colname_row, axis=1).all(axis=1)
7071

71-
# Mark as header-like if >= half of columns match their own names
72-
threshold = max(2, int(0.5 * len(cols)))
73-
mask = hits >= threshold
72+
# Extra-safe: also catch partial header rows by checking key outcome columns if present
73+
key_cols = [c for c in ["FSTAT", "LENFOL"] if c in df.columns]
74+
if key_cols:
75+
mask_key = df[key_cols].astype(str).apply(lambda s: s.str.strip()).eq(
76+
pd.Series({c: c for c in key_cols}), axis=1
77+
).all(axis=1)
78+
mask = mask_full | mask_key
79+
else:
80+
mask = mask_full
7481

7582
if mask.any():
7683
df = df.loc[~mask].copy()
7784
df.reset_index(drop=True, inplace=True)
7885

7986
return df
8087

88+
8189
@data(1)
8290
@algorithm_client
8391
def partial_risk_prediction(

0 commit comments

Comments
 (0)