@@ -554,7 +554,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
554
554
case SMB_DIRECT_MSG_DATA_TRANSFER : {
555
555
struct smb_direct_data_transfer * data_transfer =
556
556
(struct smb_direct_data_transfer * )recvmsg -> packet ;
557
- unsigned int data_length ;
557
+ u32 remaining_data_length , data_offset , data_length ;
558
558
int avail_recvmsg_count , receive_credits ;
559
559
560
560
if (wc -> byte_len <
@@ -564,15 +564,25 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
564
564
return ;
565
565
}
566
566
567
+ remaining_data_length = le32_to_cpu (data_transfer -> remaining_data_length );
567
568
data_length = le32_to_cpu (data_transfer -> data_length );
568
- if (data_length ) {
569
- if (wc -> byte_len < sizeof (struct smb_direct_data_transfer ) +
570
- (u64 )data_length ) {
571
- put_recvmsg (t , recvmsg );
572
- smb_direct_disconnect_rdma_connection (t );
573
- return ;
574
- }
569
+ data_offset = le32_to_cpu (data_transfer -> data_offset );
570
+ if (wc -> byte_len < data_offset ||
571
+ wc -> byte_len < (u64 )data_offset + data_length ) {
572
+ put_recvmsg (t , recvmsg );
573
+ smb_direct_disconnect_rdma_connection (t );
574
+ return ;
575
+ }
576
+ if (remaining_data_length > t -> max_fragmented_recv_size ||
577
+ data_length > t -> max_fragmented_recv_size ||
578
+ (u64 )remaining_data_length + (u64 )data_length >
579
+ (u64 )t -> max_fragmented_recv_size ) {
580
+ put_recvmsg (t , recvmsg );
581
+ smb_direct_disconnect_rdma_connection (t );
582
+ return ;
583
+ }
575
584
585
+ if (data_length ) {
576
586
if (t -> full_packet_received )
577
587
recvmsg -> first_segment = true;
578
588
@@ -1209,78 +1219,130 @@ static int smb_direct_writev(struct ksmbd_transport *t,
1209
1219
bool need_invalidate , unsigned int remote_key )
1210
1220
{
1211
1221
struct smb_direct_transport * st = smb_trans_direct_transfort (t );
1212
- int remaining_data_length ;
1213
- int start , i , j ;
1214
- int max_iov_size = st -> max_send_size -
1222
+ size_t remaining_data_length ;
1223
+ size_t iov_idx ;
1224
+ size_t iov_ofs ;
1225
+ size_t max_iov_size = st -> max_send_size -
1215
1226
sizeof (struct smb_direct_data_transfer );
1216
1227
int ret ;
1217
- struct kvec vec ;
1218
1228
struct smb_direct_send_ctx send_ctx ;
1229
+ int error = 0 ;
1219
1230
1220
1231
if (st -> status != SMB_DIRECT_CS_CONNECTED )
1221
1232
return - ENOTCONN ;
1222
1233
1223
1234
//FIXME: skip RFC1002 header..
1235
+ if (WARN_ON_ONCE (niovs <= 1 || iov [0 ].iov_len != 4 ))
1236
+ return - EINVAL ;
1224
1237
buflen -= 4 ;
1238
+ iov_idx = 1 ;
1239
+ iov_ofs = 0 ;
1225
1240
1226
1241
remaining_data_length = buflen ;
1227
1242
ksmbd_debug (RDMA , "Sending smb (RDMA): smb_len=%u\n" , buflen );
1228
1243
1229
1244
smb_direct_send_ctx_init (st , & send_ctx , need_invalidate , remote_key );
1230
- start = i = 1 ;
1231
- buflen = 0 ;
1232
- while (true) {
1233
- buflen += iov [i ].iov_len ;
1234
- if (buflen > max_iov_size ) {
1235
- if (i > start ) {
1236
- remaining_data_length -=
1237
- (buflen - iov [i ].iov_len );
1238
- ret = smb_direct_post_send_data (st , & send_ctx ,
1239
- & iov [start ], i - start ,
1240
- remaining_data_length );
1241
- if (ret )
1245
+ while (remaining_data_length ) {
1246
+ struct kvec vecs [SMB_DIRECT_MAX_SEND_SGES - 1 ]; /* minus smbdirect hdr */
1247
+ size_t possible_bytes = max_iov_size ;
1248
+ size_t possible_vecs ;
1249
+ size_t bytes = 0 ;
1250
+ size_t nvecs = 0 ;
1251
+
1252
+ /*
1253
+ * For the last message remaining_data_length should
1254
+ * have been 0 already!
1255
+ */
1256
+ if (WARN_ON_ONCE (iov_idx >= niovs )) {
1257
+ error = - EINVAL ;
1258
+ goto done ;
1259
+ }
1260
+
1261
+ /*
1262
+ * We have 2 factors which limit the arguments we pass
1263
+ * to smb_direct_post_send_data():
1264
+ *
1265
+ * 1. The number of supported sges for the send,
1266
+ * while one is reserved for the smbdirect header.
1267
+ * And we currently need one SGE per page.
1268
+ * 2. The number of negotiated payload bytes per send.
1269
+ */
1270
+ possible_vecs = min_t (size_t , ARRAY_SIZE (vecs ), niovs - iov_idx );
1271
+
1272
+ while (iov_idx < niovs && possible_vecs && possible_bytes ) {
1273
+ struct kvec * v = & vecs [nvecs ];
1274
+ int page_count ;
1275
+
1276
+ v -> iov_base = ((u8 * )iov [iov_idx ].iov_base ) + iov_ofs ;
1277
+ v -> iov_len = min_t (size_t ,
1278
+ iov [iov_idx ].iov_len - iov_ofs ,
1279
+ possible_bytes );
1280
+ page_count = get_buf_page_count (v -> iov_base , v -> iov_len );
1281
+ if (page_count > possible_vecs ) {
1282
+ /*
1283
+ * If the number of pages in the buffer
1284
+ * is too much (because we currently require
1285
+ * one SGE per page), we need to limit the
1286
+ * length.
1287
+ *
1288
+ * We know possible_vecs is at least 1,
1289
+ * so we always keep the first page.
1290
+ *
1291
+ * We need to calculate the number of extra
1292
+ * pages (epages) we can also keep.
1293
+ *
1294
+ * We calculate the number of bytes in the
1295
+ * first page (fplen), this should never be
1296
+ * larger than v->iov_len because page_count is
1297
+ * at least 2, but adding a limitation feels
1298
+ * better.
1299
+ *
1300
+ * Then we calculate the number of bytes (elen)
1301
+ * we can keep for the extra pages.
1302
+ */
1303
+ size_t epages = possible_vecs - 1 ;
1304
+ size_t fpofs = offset_in_page (v -> iov_base );
1305
+ size_t fplen = min_t (size_t , PAGE_SIZE - fpofs , v -> iov_len );
1306
+ size_t elen = min_t (size_t , v -> iov_len - fplen , epages * PAGE_SIZE );
1307
+
1308
+ v -> iov_len = fplen + elen ;
1309
+ page_count = get_buf_page_count (v -> iov_base , v -> iov_len );
1310
+ if (WARN_ON_ONCE (page_count > possible_vecs )) {
1311
+ /*
1312
+ * Something went wrong in the above
1313
+ * logic...
1314
+ */
1315
+ error = - EINVAL ;
1242
1316
goto done ;
1243
- } else {
1244
- /* iov[start] is too big, break it */
1245
- int nvec = (buflen + max_iov_size - 1 ) /
1246
- max_iov_size ;
1247
-
1248
- for (j = 0 ; j < nvec ; j ++ ) {
1249
- vec .iov_base =
1250
- (char * )iov [start ].iov_base +
1251
- j * max_iov_size ;
1252
- vec .iov_len =
1253
- min_t (int , max_iov_size ,
1254
- buflen - max_iov_size * j );
1255
- remaining_data_length -= vec .iov_len ;
1256
- ret = smb_direct_post_send_data (st , & send_ctx , & vec , 1 ,
1257
- remaining_data_length );
1258
- if (ret )
1259
- goto done ;
1260
1317
}
1261
- i ++ ;
1262
- if (i == niovs )
1263
- break ;
1264
1318
}
1265
- start = i ;
1266
- buflen = 0 ;
1267
- } else {
1268
- i ++ ;
1269
- if (i == niovs ) {
1270
- /* send out all remaining vecs */
1271
- remaining_data_length -= buflen ;
1272
- ret = smb_direct_post_send_data (st , & send_ctx ,
1273
- & iov [start ], i - start ,
1274
- remaining_data_length );
1275
- if (ret )
1276
- goto done ;
1277
- break ;
1319
+ possible_vecs -= page_count ;
1320
+ nvecs += 1 ;
1321
+ possible_bytes -= v -> iov_len ;
1322
+ bytes += v -> iov_len ;
1323
+
1324
+ iov_ofs += v -> iov_len ;
1325
+ if (iov_ofs >= iov [iov_idx ].iov_len ) {
1326
+ iov_idx += 1 ;
1327
+ iov_ofs = 0 ;
1278
1328
}
1279
1329
}
1330
+
1331
+ remaining_data_length -= bytes ;
1332
+
1333
+ ret = smb_direct_post_send_data (st , & send_ctx ,
1334
+ vecs , nvecs ,
1335
+ remaining_data_length );
1336
+ if (unlikely (ret )) {
1337
+ error = ret ;
1338
+ goto done ;
1339
+ }
1280
1340
}
1281
1341
1282
1342
done :
1283
1343
ret = smb_direct_flush_send_list (st , & send_ctx , true);
1344
+ if (unlikely (!ret && error ))
1345
+ ret = error ;
1284
1346
1285
1347
/*
1286
1348
* As an optimization, we don't wait for individual I/O to finish
@@ -1744,6 +1806,11 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
1744
1806
return - EINVAL ;
1745
1807
}
1746
1808
1809
+ if (device -> attrs .max_send_sge < SMB_DIRECT_MAX_SEND_SGES ) {
1810
+ pr_err ("warning: device max_send_sge = %d too small\n" ,
1811
+ device -> attrs .max_send_sge );
1812
+ return - EINVAL ;
1813
+ }
1747
1814
if (device -> attrs .max_recv_sge < SMB_DIRECT_MAX_RECV_SGES ) {
1748
1815
pr_err ("warning: device max_recv_sge = %d too small\n" ,
1749
1816
device -> attrs .max_recv_sge );
@@ -1767,7 +1834,7 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
1767
1834
1768
1835
cap -> max_send_wr = max_send_wrs ;
1769
1836
cap -> max_recv_wr = t -> recv_credit_max ;
1770
- cap -> max_send_sge = max_sge_per_wr ;
1837
+ cap -> max_send_sge = SMB_DIRECT_MAX_SEND_SGES ;
1771
1838
cap -> max_recv_sge = SMB_DIRECT_MAX_RECV_SGES ;
1772
1839
cap -> max_inline_data = 0 ;
1773
1840
cap -> max_rdma_ctxs = t -> max_rw_credits ;
0 commit comments